
SortingComponents Overview Tutorial

Author: Paul Adkisson

The spikeinterface.sortingcomponents module breaks up the sorting process into distinct steps or components. This allows users to easily build their own spike sorting pipelines by mixing and matching existing components or swapping one (or more) out with their own version(s). We are also developing benchmarks to objectively evaluate performance for each component separately.

DISCLAIMER: This module is under heavy development, so signatures and behaviors may change from time to time. In this tutorial we used spikeinterface version 0.97.0.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import colormaps
from matplotlib.colors import LogNorm
from matplotlib.cm import ScalarMappable
%matplotlib inline
from pathlib import Path
import warnings
warnings.simplefilter("ignore") # warnings are filtered out (and manually removed) for improved readability

import spikeinterface.core as sc
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
from spikeinterface.sortingcomponents.matching import find_spikes_from_templates

Download the Data

For this tutorial, we will use a Neuropixels dataset freely available on DANDI (sub-npI1_ses-20190413_behavior+ecephys.nwb)

Load, Preprocess & Visualize Data

For this tutorial we will breeze through the initial data reading, preprocessing, and visualization, since these are all implemented in other spikeinterface modules. If you need a refresher on any of these steps, please see the appropriate tutorials and documentation.

In [2]:
basepath = Path("/Volumes/T7/CatalystNeuro")
filepath = basepath / "TutorialDatasets/sub-npI1_ses-20190413_behavior+ecephys.nwb"
basepath = basepath / "SortingComponentsTutorial"
if not basepath.exists():
    basepath.mkdir()
recording = se.read_nwb(filepath)
print("recording =", recording)
fig, ax = plt.subplots(figsize=(10, 15))
_ = sw.plot_probe_map(recording, ax=ax)
recording = NwbRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 4732.300s
  file_path: /Volumes/T7/CatalystNeuro/TutorialDatasets/sub-npI1_ses-20190413_behavior+ecephys.nwb

This dataset has a probe with 384 channels on a single shank as visualized above.

Preprocessing should take ~6mins

In [3]:
# Handy kwargs for parallel computing
job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)

# We use only the first 10mins for speed
rec10mins = recording.frame_slice(start_frame=0, end_frame=int(recording.sampling_frequency * 10 * 60))
print("rec10mins = ", rec10mins)
rec_filt = spre.bandpass_filter(rec10mins, freq_min=300., freq_max=6000.,
                                dtype='float32') # tridesclous template matching requires float32 dtype
rec_pproc = spre.common_reference(rec_filt, reference='global', operator='median')
# Cache output to perform the preprocessing computation up-front rather than lazily
preprocpath = basepath / "preproc"
if not preprocpath.exists():
    rec_pproc.save(folder=preprocpath, **job_kwargs)
rec_pproc = sc.load_extractor(preprocpath)
  
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_ylabel("Channel Depth (um)")
_ = sw.plot_timeseries({"raw":rec10mins, "preprocessed":rec_pproc}, channel_ids=recording.channel_ids[50:60], ax=ax)
rec10mins =  FrameSliceRecording: 384 channels - 1 segments - 30.0kHz - 600.000s

Above we can see a short snippet of the raw and preprocessed voltage traces

Peak Detection

The classical first step of a spike sorting pipeline is detecting the prospective spikes, which spikeinterface implements in spikeinterface.sortingcomponents.peak_detection.detect_peaks. Here we will compare a couple of common approaches:

  • 'by_channel' which detects peaks independently on each channel
  • 'locally_exclusive' which detects peaks in a local neighborhood of a parameterized size
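
As a rough illustration of the difference between the two approaches, the following NumPy-only sketch runs both on a tiny synthetic recording. It is not the spikeinterface implementation: the helper names (detect_by_channel, keep_locally_exclusive), the toy geometry, and the threshold are assumptions made for this example, and the exclusion is applied only across channels at the same sample (the real method also uses a time window, exclude_sweep_ms).

```python
import numpy as np

def detect_by_channel(traces, threshold):
    """Return (sample_ind, channel_ind) of negative local minima below -threshold, per channel."""
    center = traces[1:-1]
    is_peak = (center < -threshold) & (center < traces[:-2]) & (center <= traces[2:])
    sample_ind, channel_ind = np.nonzero(is_peak)
    return sample_ind + 1, channel_ind  # +1 because `center` starts at sample 1

def keep_locally_exclusive(traces, sample_ind, channel_ind, channel_depths, radius_um):
    """Keep a peak only if it is the most negative value among channels within radius_um."""
    keep = []
    for s, c in zip(sample_ind, channel_ind):
        neighbors = np.abs(channel_depths - channel_depths[c]) <= radius_um
        keep.append(traces[s, c] == traces[s, neighbors].min())
    keep = np.array(keep, dtype=bool)
    return sample_ind[keep], channel_ind[keep]

# Toy recording: 100 samples x 4 channels spaced 20 um apart.
# A single spike at sample 50 is visible on channels 0, 1 and 2.
traces = np.zeros((100, 4))
traces[50, :3] = [-6.0, -10.0, -7.0]
channel_depths = np.array([0.0, 20.0, 40.0, 60.0])

s_all, c_all = detect_by_channel(traces, threshold=5.0)
s_excl, c_excl = keep_locally_exclusive(traces, s_all, c_all, channel_depths, radius_um=50.0)
print("by channel:", len(s_all), "peaks; locally exclusive:", len(s_excl), "peak(s)")
```

In this toy case the per-channel detection reports the same spike three times, while the locally exclusive variant keeps only the peak on the channel with the largest deflection.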

Peak Detection should take approximately 5mins / method

In [4]:
# To detect peaks properly we need the channel noise levels
noise_levels = sc.get_noise_levels(rec_pproc, return_scaled=False)
# handy common peak detection kwargs
detect_kwargs = dict(peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2, noise_levels=noise_levels)
detect_methods = ['by_channel', 'locally_exclusive']
detected_peaks = {}
for method in detect_methods:
    print(f"Method: {method}")
    peaks = detect_peaks(rec_pproc, method=method, **detect_kwargs, **job_kwargs)
    detected_peaks[method] = peaks
    
print("peaks.shape =", peaks.shape)
print("peaks.dtype =", peaks.dtype)
Method: by_channel
detect peaks:   0%|          | 0/600 [00:00<?, ?it/s]
Method: locally_exclusive
detect peaks:   0%|          | 0/600 [00:00<?, ?it/s]
peaks.shape = (822057,)
peaks.dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]

As we can see, the returned peaks are structured numpy arrays with fields

  • 'sample_ind' : the time index of each peak
  • 'channel_ind' : the channel index of each peak
  • 'amplitude' : the spike amplitude (uV)
  • 'segment_ind' : the segment index of each peak
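
For readers less familiar with structured arrays, here is a small sketch of how such an array can be indexed by field name and filtered with boolean masks. The field names match the dtype printed above; the three records and the toy_peaks variable are made up for illustration.

```python
import numpy as np

# Same field names as the peaks dtype printed above; the values are invented.
peak_dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'),
              ('amplitude', '<f8'), ('segment_ind', '<i8')]
toy_peaks = np.array([(150, 3, -42.0, 0),
                      (30_000, 10, -17.5, 0),
                      (45_000, 3, -88.1, 0)], dtype=peak_dtype)

sampling_frequency = 30_000.0  # Hz, as for the recording in this tutorial

peak_times_s = toy_peaks['sample_ind'] / sampling_frequency  # sample index -> seconds
on_channel_3 = toy_peaks[toy_peaks['channel_ind'] == 3]      # mask keeps whole records
largest = toy_peaks[np.argmin(toy_peaks['amplitude'])]       # most negative amplitude

print(peak_times_s)
print(on_channel_3['amplitude'])
print(largest['sample_ind'])
```
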
In [5]:
print(f"# of peaks detected 'by channel': {detected_peaks['by_channel'].size:,.0f}")
print(f"# of peaks detected 'locally exclusive': {detected_peaks['locally_exclusive'].size:,.0f}")
# of peaks detected 'by channel': 1,878,215
# of peaks detected 'locally exclusive': 822,057

First, we note that detecting peaks by the 'locally exclusive' method returns far fewer potential spikes. This is because it excludes nearby channels from double-counting the same spike (within radius local_radius_um).

Now, let's see how the amplitudes of the detected peaks differ.

In [6]:
hist_kwargs = dict(bins=40, histtype='step', density=True)
plt.figure()
for method in detect_methods:
    peaks = detected_peaks[method]
    plt.hist(np.log10(np.abs(peaks['amplitude'])), **hist_kwargs, label=method)
plt.xlabel("log10(Absolute Spike Amplitude)")
plt.ylabel("Probability density")  # density=True normalizes the histogram area to 1
_ = plt.legend(loc="upper right")

From the histogram above we can see a log-normal spike amplitude distribution, which is typical for neural recordings. Both methods detect a similar distribution of peaks, but 'locally_exclusive' is shifted slightly toward higher-amplitude peaks. This is likely due to the 'locally_exclusive' method omitting lower-amplitude peaks from channels adjacent to the principal channel.

Peak Localization

Many spike sorters require peak locations in order to cluster individual units that are localized in space. We will compare several peak localization methods implemented in spikeinterface.sortingcomponents.peak_localization.localize_peaks.

Center-of-mass localization takes ~1min

Monopolar Triangulation takes ~15mins

In [7]:
detect_method = "locally_exclusive"  # we will use this method because it gave the best results
peaks = detected_peaks[detect_method]
# Handy kwargs for localizing peaks
localize_kwargs = dict(local_radius_um=150, ms_before=0.3, ms_after=0.6, max_distance_um=2_000)

localize_methods = ['center_of_mass', 'monopolar_triangulation']
localized_peaks = {}
for method in localize_methods:
    print(f"Method: {method}")
    localized_path = basepath / f"peak_locations_{method}.npy"
    if not localized_path.exists():
        peaklocs = localize_peaks(rec_pproc, peaks, method=method, **localize_kwargs, **job_kwargs)
        np.save(localized_path, peaklocs)
    else:
        peaklocs = np.load(localized_path)
    localized_peaks[method] = peaklocs

display("localized_peaks (center-of-mass) =", localized_peaks["center_of_mass"])
display("localized_peaks (monopolar triangulation) =", localized_peaks["monopolar_triangulation"])
Method: center_of_mass
Method: monopolar_triangulation
'localized_peaks (center-of-mass) ='
array([(35.44141803, 1415.00134048), (34.90595181, 1456.86399725),
       (35.44883201, 1504.80573544), ..., (31.41470648, 1381.40171115),
       (35.94058364, 1291.2247614 ), (32.77912362, 1224.25837716)],
      dtype=[('x', '<f8'), ('y', '<f8')])
'localized_peaks (monopolar triangulation) ='
array([( 67.27312732, 1435.29692795, 84.80512236, 3881.68338308),
       ( 77.99944805, 1403.71987755, 86.54212425, 4250.68337684),
       ( 59.99148254, 1403.22089483, 83.14741967, 4161.36286205), ...,
       ( 20.04527582, 1368.48668965,  1.85345906, 1385.65055837),
       ( 60.99929182, 1288.92955072, 20.73237984, 1551.61844719),
       (-35.1700467 , 1272.72233331, 71.09445816, 3104.1604341 )],
      dtype=[('x', '<f8'), ('y', '<f8'), ('z', '<f8'), ('alpha', '<f8')])

As we can see here, the returned peak locations are structured numpy arrays. The center-of-mass method computes an average of the nearby channel positions, weighted by the peak amplitude on each channel, and returns fields

  • 'x' : Horizontal Location (breadth in um)
  • 'y' : Vertical Location (depth in um)
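
The center-of-mass idea can be sketched in a few lines of NumPy. The channel positions and amplitudes below are hypothetical values chosen for illustration, and this is not the spikeinterface code itself.

```python
import numpy as np

# Hypothetical (x, y) positions in um of four channels around a peak
channel_positions = np.array([[16.0, 0.0],
                              [48.0, 0.0],
                              [16.0, 20.0],
                              [48.0, 20.0]])
# Hypothetical absolute peak amplitudes measured on those channels
amplitudes = np.array([10.0, 30.0, 20.0, 40.0])

# Normalize the amplitudes into weights, then average the channel positions
weights = amplitudes / amplitudes.sum()
com_x, com_y = weights @ channel_positions
print(com_x, com_y)
```

Because the weights are non-negative and sum to one, the estimate always falls within the area spanned by the contributing channels.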

The monopolar triangulation method performs least-squares optimization on the peak amplitudes to distinguish between nearby low-amplitude spikes and far away high-amplitude spikes that might present the same amplitude on the principal channel. So, it adds two additional fields

  • 'z' : Perpendicular Distance to the probe (um)
  • 'alpha' : Estimated Amplitude at the source (uV)

Each entry corresponds to the same indexed peak in the input peaks array.
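
To see why the extra 'z' and 'alpha' fields matter, consider the following sketch. It assumes a simple point-source model in which the amplitude on a channel equals alpha divided by the distance to the source; the function name monopole_amplitudes, the geometry, and the numbers are illustrative assumptions, and the actual spikeinterface optimization may differ in its details.

```python
import numpy as np

def monopole_amplitudes(source_xyz, alpha, channel_xy):
    """Amplitude on each channel for a point source, assuming amplitude = alpha / distance."""
    x, y, z = source_xyz
    dist = np.sqrt((channel_xy[:, 0] - x) ** 2 + (channel_xy[:, 1] - y) ** 2 + z ** 2)
    return alpha / dist

# Three channels in a column (x, y in um); the probe lies in the z = 0 plane
channel_xy = np.array([[0.0, 0.0], [0.0, 40.0], [0.0, 80.0]])

# A weak source close to the probe and a strong source far from it
near = monopole_amplitudes((0.0, 40.0, 10.0), alpha=500.0, channel_xy=channel_xy)
far = monopole_amplitudes((0.0, 40.0, 100.0), alpha=5000.0, channel_xy=channel_xy)
print("near:", near)
print("far: ", far)
```

Both sources produce the same amplitude on the middle (principal) channel, but the far source decays much more slowly across the neighboring channels. Fitting the whole amplitude profile, rather than the principal channel alone, is what lets the two cases be told apart.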

In [8]:
print(f"# of peaks localized 'center-of-mass': {localized_peaks['center_of_mass'].size:,.0f}")
print(f"# of peaks localized 'monopolar triangulation': {localized_peaks['monopolar_triangulation'].size:,.0f}")
# of peaks localized 'center-of-mass': 822,057
# of peaks localized 'monopolar triangulation': 822,057

Let's visualize the estimated locations on the probe!

In [9]:
localize_titles = {'center_of_mass': "Center of Mass", 'monopolar_triangulation':"Monopolar Triangulation"}
fig, ax = plt.subplots(1, len(localize_methods), figsize=(10, 20))
for i, method in enumerate(localize_methods):
    sw.plot_probe_map(rec_pproc, ax=ax[i])
    peaklocs = localized_peaks[method]
    ax[i].scatter(peaklocs['x'], peaklocs['y'], color="k", s=1, alpha=0.005)
    ax[i].set_title(localize_titles[method])

xlims = [-300, 300]
ax[1].set_xlim(xlims)
Out[9]:
(-300.0, 300.0)

From the 2D plot alone, we can see that monopolar triangulation creates distinct clouds of spikes that will be useful for clustering. The center-of-mass method tends to pull all the locations toward the center of the probe and cannot assign spike locations outside the bounds of the probe.

Let's see what additional information is provided by the two extra fields from monopolar triangulation. For ease of visualization we will use 2D projections onto each pair of axes.

In [10]:
peaklocs = localized_peaks['monopolar_triangulation']
min_alpha = np.min(peaklocs['alpha'])
max_alpha = np.max(peaklocs['alpha'])
norm = LogNorm(vmin=min_alpha, vmax=max_alpha)
inferno = colormaps['inferno']
plot_kwargs = dict(c=peaklocs['alpha'], norm=norm, cmap=inferno, s=1, alpha=0.005)
mappable = ScalarMappable(norm=norm, cmap=inferno)

fig = plt.figure(figsize=(7.5, 10))
plt.subplots_adjust(wspace=0.5)
ax0 = plt.subplot2grid((6, 3), (0, 0), rowspan=6)
ax0.scatter(peaklocs['x'], peaklocs['y'], **plot_kwargs)
ax0.set_xlim(xlims)
ax0.set_ylim([0, 4000])
ax0.set_xlabel("x (um)")
ax0.set_ylabel("y (um)")

ax1 = plt.subplot2grid((6, 3), (0, 1), rowspan=6)
ax1.scatter(peaklocs['z'], peaklocs['y'], **plot_kwargs)
ax1.set_xlabel("z (um)")
ax1.set_ylabel("y (um)")
ax1.set_xlim([0, 600])
ax1.set_ylim([0, 4000]) 

ax2 = plt.subplot2grid((6, 3), (5, 2), rowspan=1)
ax2.scatter(peaklocs['x'], peaklocs['z'], **plot_kwargs)
ax2.set_xlim(xlims)
ax2.set_ylim([0, 600])
ax2.set_xlabel("x (um)")
ax2.set_ylabel("z (um)")

cax = plt.subplot2grid((6, 3), (0, 2), rowspan=5)
cbar = fig.colorbar(mappable, ax=cax, label="alpha (uV)")
cax.set_visible(False)

We can see a wide range of source spike amplitudes (note the logarithmic color scale), and cloud-like blobs of spikes in each projection. As expected, we generally observe higher-amplitude spikes at farther distances from the probe in each direction. Clearly, the extra time investment (~15mins vs ~1min) for monopolar triangulation provided important new information.

Peak Selection

Since we used a high-density Neuropixels probe, we detected 822,057 spikes in only a 10-minute recording. Unfortunately, clustering algorithms tend to be quite slow and struggle to handle large numbers of spikes (e.g., agglomerative hierarchical clustering needs at least $\mathcal{O}(n^2)$ time for $n$ spikes).

To get around this problem, many modern spike sorters select a subset of these peaks to use for clustering. How these peaks are sampled is critical to ensure that no units are lost in the process. We will compare several methods already implemented in spikeinterface.sortingcomponents.peak_selection.select_peaks.

In [11]:
peaks = detected_peaks[detect_method]
localize_method = 'monopolar_triangulation' # we will use this method because it gave the best results
peaklocs = localized_peaks[localize_method]
select_methods = ["uniform", "smart_sampling_amplitudes", "smart_sampling_locations"]
select_kwargs = dict(n_peaks = 10_000, noise_levels=noise_levels, by_channel=False)
selected_peaks = {}
for method in select_methods:
    print(f"Method: {method}")
    if 'locations' in method:
        selected_peaks[method] = select_peaks(peaks, method=method, peaks_locations=peaklocs, seed=0,
                                              **select_kwargs)
    else:
        selected_peaks[method] = select_peaks(peaks, method=method, **select_kwargs)
print("peaks.shape=", selected_peaks['uniform'].shape)
print("peaks.dtype=", selected_peaks['uniform'].dtype)
Method: uniform
Method: smart_sampling_amplitudes
Method: smart_sampling_locations
peaks.shape= (10000,)
peaks.dtype= [('sample_ind', '<i8'), ('channel_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]

As expected, select_peaks returns a structured numpy array with the same fields as the peaks input, but subsampled to only n_peaks.
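
The simplest of these strategies, uniform random subsampling, can be sketched with NumPy alone. This is only an illustration on synthetic data (toy_peaks and the sizes are made up), not the code used inside select_peaks.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Synthetic peaks with the same fields as the tutorial's peaks array
peak_dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'),
              ('amplitude', '<f8'), ('segment_ind', '<i8')]
n_total, n_keep = 1_000, 100
toy_peaks = np.zeros(n_total, dtype=peak_dtype)
toy_peaks['sample_ind'] = np.sort(rng.integers(0, 30_000 * 600, size=n_total))
toy_peaks['amplitude'] = -rng.lognormal(mean=3.0, sigma=0.5, size=n_total)

# Draw n_keep distinct indices, then sort them so the subset stays in time order
keep = np.sort(rng.choice(n_total, size=n_keep, replace=False))
toy_selected = toy_peaks[keep]

print(toy_selected.shape, toy_selected.dtype == toy_peaks.dtype)
```

Indexing a structured array with an integer index array returns whole records, so the subset keeps every field of the original peaks.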

Let's see how the amplitudes of the sampled peaks compare to the base distribution

In [12]:
hist_kwargs = dict(histtype='step', bins=15, density=True)
plt.figure()
plt.hist(np.log10(np.abs(detected_peaks[detect_method]['amplitude'])), label="all spikes", **hist_kwargs)
for method in select_methods:
    plt.hist(np.log10(np.abs(selected_peaks[method]['amplitude'])), **hist_kwargs, label=method)
plt.xlabel("log10(Absolute Spike Amplitude)")
plt.ylabel("Probability density")  # density=True normalizes the histogram area to 1
_ = plt.legend(bbox_to_anchor=(1, 1))

All sampling methods roughly maintain the spike amplitude distribution. The 'smart_sampling_amplitudes' method appears to mildly over-sample the high-amplitude spikes at the expense of the low-amplitude spikes.

Let's look at the location distributions as well

In [13]:
# relocalizing peaks in each subset
selected_peaklocs = {}
for method in select_methods:
    print(f"Method: {method}")
    peaks = selected_peaks[method]
    peaklocs = localize_peaks(rec_pproc, peaks, method=localize_method,
                                          **localize_kwargs, **job_kwargs)
    selected_peaklocs[method] = peaklocs
Method: uniform
localize peaks:   0%|          | 0/600 [00:00<?, ?it/s]
Method: smart_sampling_amplitudes
localize peaks:   0%|          | 0/600 [00:00<?, ?it/s]
Method: smart_sampling_locations
localize peaks:   0%|          | 0/600 [00:00<?, ?it/s]
In [14]:
fig, ax = plt.subplots(1, len(select_methods)+1, figsize=(10, 20), sharey=True)
xlims = [-300, 300]
for i in range(len(select_methods)+1):
    sw.plot_probe_map(rec_pproc, ax=ax[i])

all_peaklocs = localized_peaks[localize_method]
ax[0].scatter(all_peaklocs['x'], all_peaklocs['y'], s=1, color="k", alpha=0.005)
ax[0].set_xlim(xlims)
ax[0].set_title("All Spike Locations")
method_titles = {"uniform":"Uniform",
                 "smart_sampling_amplitudes":"Smart Sampling \nAmplitudes",
                 "smart_sampling_locations":"Smart Sampling \nLocations"
                }
for i, method in enumerate(select_methods):
    peaklocs = selected_peaklocs[method]
    ax[i+1].scatter(peaklocs['x'], peaklocs['y'], color="C%i"%(i+1), s=2, alpha=0.1, label=method)
    ax[i+1].set_xlim(xlims)
    ax[i+1].set_title(method_titles[method])
    ax[i+1].set_ylabel("")
for i in [0, 1, 3]:
    ax[i].set_xlabel("")

From the plot above, we can see that all 3 peak selection methods capture a reasonable sample of the spike clouds in space. Both 'uniform' and 'smart_sampling_amplitudes' sample an approximately equal number of spikes from each distinct cloud in space. On the other hand 'smart_sampling_locations' favors the clouds in the lower half of the probe with higher spike counts at the expense of more sparse sampling of the diffuse cloud in the upper third of the probe (which is likely noise).

Clustering

Clustering is the central step of spike sorting, and remains fertile territory for innovations to improve spike sorters. Historically, this step was split into feature extraction (e.g., PCA on spike waveforms) followed by clustering on those features (e.g., k-means). In the spirit of lazy processing, however, we decided to combine these two steps into one. This decision allows spikeinterface to compute features on-the-fly and avoid storing large feature arrays.
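
For reference, the historical feature-extraction step can be sketched as a PCA computed through an SVD. The waveforms below are synthetic, and the two unit shapes, the noise level, and the choice of three components are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 40
t = np.linspace(0, 1, n_samples)

# Two hypothetical unit shapes: a narrow and a wide negative deflection
shape_a = -np.exp(-((t - 0.4) / 0.05) ** 2)
shape_b = -0.6 * np.exp(-((t - 0.5) / 0.15) ** 2)

# 100 noisy waveforms per unit, stacked into a (200, n_samples) matrix
waveforms = np.vstack([shape_a + 0.05 * rng.standard_normal((100, n_samples)),
                       shape_b + 0.05 * rng.standard_normal((100, n_samples))])

# PCA via SVD of the mean-centered waveforms; keep the first 3 components
centered = waveforms - waveforms.mean(axis=0)
_, singular_values, components = np.linalg.svd(centered, full_matrices=False)
features = centered @ components[:3].T

print(features.shape)
```

Each waveform is reduced from 40 samples to 3 numbers, and a clustering algorithm would then operate on these low-dimensional features rather than on the raw waveforms.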

Here we will compare a few clustering methods already implemented in spikeinterface.sortingcomponents.clustering

Clustering takes

  • <1min for 'position'
  • <1min for 'position_and_pca'
  • <1min for 'sliding_hdbscan'
In [15]:
select_method = 'smart_sampling_locations' # we will use this method because it gave the best results
peaks = selected_peaks[select_method]
peaklocs = selected_peaklocs[select_method]
method_kwargs = dict(peak_locations=peaklocs)
cluster_labels = {}
cluster_methods = ['position', 'position_and_pca', 'sliding_hdbscan']
for method in cluster_methods:
    print(f"Method: {method}")
    _, labels = find_cluster_from_peaks(rec_pproc, peaks, method_kwargs=method_kwargs, 
                                        method=method, **job_kwargs)
    cluster_labels[method] = labels
    
print("labels.shape =", labels.shape)
print("labels.dtype =", labels.dtype)
Method: position
Method: position_and_pca
extract waveforms shared_memory with n_jobs = 8 and chunk_size = 6510
extract waveforms shared_memory:   0%|          | 0/2765 [00:00<?, ?it/s]
Launching the local pca for splitting purposes
extract waveforms shared_memory with n_jobs = 8 and chunk_size = 6510
extract waveforms shared_memory:   0%|          | 0/2765 [00:00<?, ?it/s]
Method: sliding_hdbscan
extract waveforms shared_memory with n_jobs = 8 and chunk_size = 6510
extract waveforms shared_memory:   0%|          | 0/2765 [00:00<?, ?it/s]
extract waveforms shared_memory with n_jobs = 8 and chunk_size = 6510
extract waveforms shared_memory:   0%|          | 0/2765 [00:00<?, ?it/s]
labels.shape = (10000,)
labels.dtype = int64

As we can see, the returned peak labels are 1D numpy arrays with an integer label for each peak given to find_cluster_from_peaks. The ignored first return value holds the set of cluster labels, which is not needed here.
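
A quick way to summarize such a label array is to separate the noise peaks (negative labels, following the convention used in the plotting code of this tutorial) and count the peaks in each cluster. The toy_labels array below is made up for illustration.

```python
import numpy as np

# Hypothetical per-peak labels; negative values mark peaks assigned to noise
toy_labels = np.array([0, 0, 1, -1, 2, 1, -1, 0, 2, 2])

is_noise = toy_labels < 0
unit_labels, unit_counts = np.unique(toy_labels[~is_noise], return_counts=True)

for label, count in zip(unit_labels, unit_counts):
    print(f"cluster {label}: {count} peaks")
print(f"noise: {is_noise.sum()} peaks ({is_noise.mean():.0%} of all peaks)")
```
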

Let's visualize the clustering output with their estimated locations. We expect unit clusters to localize together in space.

In [16]:
fig, ax = plt.subplots(1, len(cluster_methods), figsize=(10, 20), sharey=True)
for i in range(len(cluster_methods)):
    sw.plot_probe_map(rec_pproc, ax=ax[i])
cluster_titles = {"position":"Position",
                 "position_and_pca":"Position & PCA",
                 "sliding_hdbscan":"Sliding HDBSCAN"
                }
for i, method in enumerate(cluster_methods):
    labels = cluster_labels[method]
    for label in set(labels):
        peaklocs_clust = peaklocs[cluster_labels[method]==label]
        if label < 0: # labels < 0 correspond to noise clusters
            ax[i].scatter(peaklocs_clust['x'], peaklocs_clust['y'], color="k", s=2, alpha=0.05, label=label)
        else:
            ax[i].scatter(peaklocs_clust['x'], peaklocs_clust['y'], color=f"C{label}", s=2, alpha=1, label=label)

    ax[i].set_xlim(xlims)
    ax[i].set_title(cluster_titles[method])
    ax[i].set_ylabel("")
for i in [0, 2]:
    ax[i].set_xlabel("")

From the plots above, we can see that

  • Positional clustering produces spatial clusters roughly as we would expect with minimal spikes designated as noise
  • Position & PCA generated only one small cluster in the lower third of the probe with all other spikes designated as noise
  • Sliding HDBSCAN split the two large clusters in the middle of the probe into many smaller ones, and designated all the other spikes as noise

To determine which of these approaches is the most reasonable, let's inspect the unit waveforms.

Extract Waveforms & Visualize Templates

Waveform extraction and visualization are already implemented in spikeinterface in different modules, so we will employ those here.

First we convert the labels arrays to NpzSortingExtractor objects

In [17]:
cluster_sortings = {}
for method in cluster_methods:
    labels = cluster_labels[method]
    true_peaks = peaks[labels>=0]
    true_labels = labels[labels>=0] # omit noise clusters
    sorting = se.NumpySorting.from_times_labels(true_peaks['sample_ind'], true_labels, rec_pproc.sampling_frequency)
    sortpath = basepath / f"sorting_{method}.npz"
    se.NpzSortingExtractor.write_sorting(sorting, sortpath)
    sorting_extractor = se.NpzSortingExtractor(sortpath)
    cluster_sortings[method] = sorting_extractor
    print(cluster_sortings[method])
NpzSortingExtractor: 3 units - 1 segments - 30.0kHz
  file_path: /Volumes/T7/CatalystNeuro/SortingComponentsTutorial/sorting_position.npz
NpzSortingExtractor: 1 units - 1 segments - 30.0kHz
  file_path: /Volumes/T7/CatalystNeuro/SortingComponentsTutorial/sorting_position_and_pca.npz
NpzSortingExtractor: 26 units - 1 segments - 30.0kHz
  file_path: /Volumes/T7/CatalystNeuro/SortingComponentsTutorial/sorting_sliding_hdbscan.npz

Next we convert recording-sorting pairs to WaveformExtractor objects

Extracting Waveforms takes ~1min / method

In [18]:
wavepaths = {method : basepath / f"waveforms_{method}" for method in cluster_methods}
cluster_wes = {}
for method in cluster_methods:
    print(f"Method: {method}")
    wavepath = wavepaths[method]
    sorting = cluster_sortings[method]
    we = sc.extract_waveforms(rec_pproc, sorting, wavepath, ms_before=1.5, ms_after=2, overwrite=True, **job_kwargs)
    cluster_wes[method] = we
    print(cluster_wes[method])
Method: position
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
WaveformExtractor: 384 channels - 3 units - 1 segments
  before:45 after:60 n_per_units:500
Method: position_and_pca
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
WaveformExtractor: 384 channels - 1 units - 1 segments
  before:45 after:60 n_per_units:500
Method: sliding_hdbscan
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
WaveformExtractor: 384 channels - 26 units - 1 segments
  before:45 after:60 n_per_units:500

Visualize waveforms

In [19]:
n_cols = 5
for method in cluster_methods:
    sorting = cluster_sortings[method]
    we = cluster_wes[method]
    sparsity = sc.ChannelSparsity.from_radius(we, radius_um=1) # only interested in principal channel waveforms
    unit_colors = {unit_id: "C%i" % unit_id for unit_id in sorting.unit_ids}
    n_units = len(sorting.unit_ids)
    ncols = min(n_units, n_cols)
    fig = plt.figure(figsize=(2*ncols, 2*(1 + n_units//n_cols)), constrained_layout=True)
    fig.suptitle(cluster_titles[method], fontsize='x-large')
    sw.plot_unit_waveforms(we, unit_ids=sorting.unit_ids, sparsity=sparsity, max_spikes_per_unit=100,
                           alpha_waveforms=0.05, lw_templates=2, unit_colors=unit_colors,
                           ncols=min(ncols, n_units), figure=fig)
    axes = fig.axes
    for i, ax in enumerate(axes):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Unit {i}")

From the waveforms, we can see that the two clusters from the 'position' method are likely undersplit, each representing multiple units. The 'position_and_pca' method seems to be missing many spikes that have been incorrectly designated as noise. The 'sliding_hdbscan' method picks up many potential units, some of which look very similar, so this method may be oversplitting. On the other hand, other units contain groups of spikes with distinct amplitudes/shapes which are likely under-split. To investigate further we could employ some of the many tools in spikeinterface.qualitymetrics to assess the sorting quality.

For this tutorial, though, we will simply proceed with 'sliding_hdbscan' since it detected the most units.

Template Matching

The final step of many spike sorting algorithms is to re-detect putative spikes using the templates built from previously clustered units. Typically, these algorithms try to explain the voltage traces as a linear sum of templates plus some residual noise. This allows spikes to be detected and automatically sorted even if they are co-occurring in time and space.
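
The "linear sum of templates plus residual" idea can be illustrated with a small least-squares fit. This is a deliberately simplified sketch: the two single-channel templates are invented, and the spike times are assumed to be known, whereas real template matching methods must also search over spike times and template identities.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 60
t = np.arange(n_samples)

# Two hypothetical single-channel templates, centered 6 samples apart so they overlap
template_a = -np.exp(-((t - 25) / 3.0) ** 2)
template_b = -np.exp(-((t - 31) / 5.0) ** 2)

# Simulated trace: both units fire with different amplitudes, plus a little noise
true_amplitudes = np.array([1.0, 0.5])
trace = true_amplitudes[0] * template_a + true_amplitudes[1] * template_b
trace = trace + 0.01 * rng.standard_normal(n_samples)

# Least-squares fit of trace ~= design @ amplitudes
design = np.column_stack([template_a, template_b])
fitted_amplitudes, *_ = np.linalg.lstsq(design, trace, rcond=None)
residual = trace - design @ fitted_amplitudes

print(fitted_amplitudes)
```

Even though the two spikes overlap in time, the fit recovers an amplitude for each template separately, which a simple threshold crossing on the summed trace could not do.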

Here we will compare a few methods implemented in spikeinterface.sortingcomponents.matching:

  • 'naive': a naive template matching approach that simply uses the closest template
  • 'tridesclous' : the template matching method from tridesclous
  • 'circus' : the template matching method from spyking circus

Template matching takes

  • ~5min for 'naive'
  • ~7min for 'tridesclous'
  • ~13min for 'circus'
In [20]:
cluster_method = 'sliding_hdbscan' # we will use this method because it gave the most units
we = cluster_wes[cluster_method]
tmatch_methods = ['naive', 'tridesclous', 'circus']
method_kwargs = {'naive':dict(waveform_extractor=we),
                 'tridesclous':dict(waveform_extractor=we, noise_levels=noise_levels, num_closest=3),
                 'circus':dict(waveform_extractor=we, noise_levels=noise_levels)}
tmatch_spikes = {}
for method in tmatch_methods:
    print(f"Method: {method}")
    spikes = find_spikes_from_templates(rec_pproc, method=method, method_kwargs=method_kwargs[method])
    tmatch_spikes[method] = spikes
    print("spikes.shape =", spikes.shape)
    print("spikes.dtype =", spikes.dtype)
Method: naive
find spikes (naive):   0%|          | 0/600 [00:00<?, ?it/s]
spikes.shape = (37553,)
spikes.dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]
Method: tridesclous
find spikes (tridesclous):   0%|          | 0/600 [00:00<?, ?it/s]
spikes.shape = (709620,)
spikes.dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]
Method: circus
find spikes (circus):   0%|          | 0/600 [00:00<?, ?it/s]
spikes.shape = (58704,)
spikes.dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]

As we can already see, the returned spikes object is a structured numpy array with the same 4 fields as the earlier peaks arrays, but with one additional 'cluster_ind' field, which corresponds to the cluster label for that spike.

Note also that, of the three methods, 'naive' finds the fewest spikes (37,553), 'circus' finds more (58,704), and 'tridesclous' finds the most (709,620).

To get a better sense of the differences between these outputs, let's inspect the unit waveforms

Extracting Waveforms takes ~2min/method

In [21]:
# Generating SortingExtractor Objects from spikes
print("...Generating Sortings...")
tmatch_sortings = {}
for method in tmatch_methods:
    print(f"Method: {method}")
    labels = tmatch_spikes[method]['cluster_ind']
    spiketime_inds = tmatch_spikes[method]['sample_ind']
    sorting = se.NumpySorting.from_times_labels(spiketime_inds, labels, rec_pproc.sampling_frequency,
                                                unit_ids=cluster_sortings['sliding_hdbscan'].unit_ids)
    sortpath = basepath / f"tmatch_sorting_{method}.npz"
    se.NpzSortingExtractor.write_sorting(sorting, sortpath)
    sorting_extractor = se.NpzSortingExtractor(sortpath)
    tmatch_sortings[method] = sorting_extractor
    print(tmatch_sortings[method])
    
# Generating WaveformExtractor Objects from SortingExtractors
print("...Generating WaveformExtractors...")
tmatch_wavepaths = {method : basepath / f"tmatch_waveforms_{method}" for method in tmatch_methods}
tmatch_wes = {}
for method in tmatch_methods:
    print(f"Method: {method}")
    wavepath = tmatch_wavepaths[method]
    sorting = tmatch_sortings[method]
    we = sc.extract_waveforms(rec_pproc, sorting, wavepath, ms_before=1.5, ms_after=2, overwrite=True, **job_kwargs)
    tmatch_wes[method] = we
    print(tmatch_wes[method])
...Generating Sortings...
Method: naive
NpzSortingExtractor: 26 units - 1 segments - 30.0kHz
  file_path: /Volumes/T7/CatalystNeuro/SortingComponentsTutorial/tmatch_sorting_naive.npz
Method: tridesclous
NpzSortingExtractor: 26 units - 1 segments - 30.0kHz
  file_path: /Volumes/T7/CatalystNeuro/SortingComponentsTutorial/tmatch_sorting_tridesclous.npz
Method: circus
NpzSortingExtractor: 26 units - 1 segments - 30.0kHz
  file_path: /Volumes/T7/CatalystNeuro/SortingComponentsTutorial/tmatch_sorting_circus.npz
...Generating WaveformExtractors...
Method: naive
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
WaveformExtractor: 384 channels - 26 units - 1 segments
  before:45 after:60 n_per_units:500
Method: tridesclous
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
WaveformExtractor: 384 channels - 26 units - 1 segments
  before:45 after:60 n_per_units:500
Method: circus
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
WaveformExtractor: 384 channels - 26 units - 1 segments
  before:45 after:60 n_per_units:500

Visualize Waveforms

In [22]:
tmatch_titles = {'naive':"Naive", 'tridesclous':"Tridesclous", 'circus':"SpyKing Circus"}
for method in tmatch_methods:
    sorting = tmatch_sortings[method]
    we = tmatch_wes[method]
    nonzero_ids = [unit_id for unit_id in sorting.unit_ids if we.get_waveforms(unit_id).size]
    sparsity = sc.ChannelSparsity.from_radius(we, radius_um=1) # only interested in principal channel waveforms
    unit_colors = {unit_id: "C%i" % unit_id for unit_id in sorting.unit_ids}
    fig = plt.figure(figsize=(10, 10))
    sw.plot_unit_waveforms(we, unit_ids=nonzero_ids, sparsity=sparsity, max_spikes_per_unit=250,
                           alpha_waveforms=0.01, unit_colors=unit_colors, ncols=5, figure=fig)
    plt.suptitle(tmatch_titles[method], fontsize='x-large')
    axes = fig.axes
    for unit_id, ax in zip(nonzero_ids, axes):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(f"Unit {unit_id}")

From these visualizations, we can see that all three methods produced many apparently well-isolated units.

The 'naive' method contains a few very noisy templates that may correspond either to pure noise spikes or under-detected units with too few spikes to properly resolve the template waveform. It also contains a few 'units' with 2 distinct spike amplitudes/timings that are strong candidates for splitting.

The 'tridesclous' method contains fewer of these problematic units, but still has some that are under-detected and others that are over-merged.

The 'circus' method also contains few problematic units, but does not detect as many as 'tridesclous' and has some very similar units with small numbers of spikes possibly indicating over-splitting.

All three template matching methods could potentially be improved by better initial clustering for the base templates.

Note: Due to randomness, exact unit numbers may vary.

Putting it all together

Now that we have explored each main component in the sortingcomponents module, we can combine our choices all together to create our own personalized spike sorting pipeline!

In [23]:
def cluster2waveform(recording, peaks, labels, path, job_kwargs=None): 
    '''Combines recording, peaks, and labels and repackages them into a waveform extractor
           with specified path'''
    if job_kwargs is None:
        job_kwargs = {}
    if not path.exists():
        path.mkdir()
        
    sorting = se.NumpySorting.from_times_labels(peaks['sample_ind'], labels, recording.sampling_frequency)
    sortpath = path / "sorting.npz"
    se.NpzSortingExtractor.write_sorting(sorting, sortpath)
    sorting_extractor = se.NpzSortingExtractor(sortpath)
    wavepath = path / "waveforms"
    we = sc.extract_waveforms(recording, sorting_extractor, wavepath,
                              ms_before=1.5, ms_after=2, overwrite=True, **job_kwargs)
    return we

def my_spikesorter(recording, path):
    # Setup
    job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True)
    recpath = path / "recording"
    if not recpath.exists():
        recording.save(folder=recpath, **job_kwargs)
    recording = sc.load_extractor(recpath)
    
    # Peak Detection
    print("...Detecting Peaks...")
    noise_levels = sc.get_noise_levels(recording, return_scaled=False)
    detect_method = 'locally_exclusive'
    detect_kwargs = dict(peak_sign='neg', detect_threshold=5, exclude_sweep_ms=0.2,
                         noise_levels=noise_levels, local_radius_um=50)
    peaks = detect_peaks(recording, method=detect_method, **detect_kwargs, **job_kwargs)
    
    # Peak Localization
    print("...Localizing Peaks...")
    localize_method = 'monopolar_triangulation'
    localize_kwargs = dict(local_radius_um=150, ms_before=0.3, ms_after=0.6)
    peaklocs = localize_peaks(recording, peaks, method=localize_method,
                              **localize_kwargs, **job_kwargs)
    
    # Peak Selection
    print("...Selecting Peaks...")
    select_method = 'smart_sampling_locations'
    select_kwargs = dict(n_peaks = 10_000, noise_levels=noise_levels, by_channel=False, seed=0)
    peaks = select_peaks(peaks, method=select_method, peaks_locations=peaklocs, **select_kwargs)
    peaklocs = localize_peaks(recording, peaks, method=localize_method,
                                  **localize_kwargs, **job_kwargs)
    
    # Clustering
    print("...Clustering...")
    cluster_method = 'sliding_hdbscan'
    cluster_kwargs = dict(peak_locations=peaklocs)
    _, labels = find_cluster_from_peaks(recording, peaks, method=cluster_method,
                                        method_kwargs=cluster_kwargs, **job_kwargs)
    
    # Curating
    print("...Curating...")
    true_peaks = peaks[labels>=0]
    true_labels = labels[labels>=0] # omit noise clusters
    we = cluster2waveform(recording, true_peaks, true_labels, path, job_kwargs=job_kwargs)
    
    # Template Matching
    print("...Template Matching...")
    tmatch_method = 'tridesclous'
    tmatch_kwargs = dict(waveform_extractor=we, noise_levels=noise_levels, num_closest=3)
    spikes = find_spikes_from_templates(recording, method=tmatch_method, method_kwargs=tmatch_kwargs)
    labels = spikes['cluster_ind']
    final_path = path / "final"
    we = cluster2waveform(recording, spikes, labels, final_path, job_kwargs=job_kwargs)
    return we

The full pipeline takes about 30 minutes to run.

In [24]:
output_folder = basepath / "my_spikesorter"
we = my_spikesorter(rec_pproc, output_folder)
write_binary_recording with n_jobs = 4 and chunk_size = 30000
write_binary_recording:   0%|          | 0/600 [00:00<?, ?it/s]
...Detecting Peaks...
detect peaks:   0%|          | 0/600 [00:00<?, ?it/s]
...Localizing Peaks...
localize peaks:   0%|          | 0/600 [00:00<?, ?it/s]
...Selecting Peaks...
localize peaks:   0%|          | 0/600 [00:00<?, ?it/s]
...Clustering...
extract waveforms shared_memory with n_jobs = 8 and chunk_size = 6510
extract waveforms shared_memory:   0%|          | 0/2765 [00:00<?, ?it/s]
extract waveforms shared_memory with n_jobs = 8 and chunk_size = 6510
extract waveforms shared_memory:   0%|          | 0/2765 [00:00<?, ?it/s]
...Curating...
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]
...Template Matching...
find spikes (tridesclous):   0%|          | 0/600 [00:00<?, ?it/s]
extract waveforms memmap:   0%|          | 0/600 [00:00<?, ?it/s]

Now we have a full pipeline that can run end-to-end on the preprocessed recording object with a very compact definition (2-5 lines) to specify each step. We selected each step by assessing it individually, but we could very easily swap components out based on their contribution to the final sorting output.

We conclude this tutorial by visualizing the resulting single-unit spike trains from our spikesorting pipeline.

In [25]:
fig, ax = plt.subplots()
sw.plot_rasters(we.sorting, time_range=[0, 10], ax=ax)
_ = ax.set_ylabel("Unit #")

spikeinterface peak localization

peak localization in spikeinterface

spikeinterface includes several methods for unit or peak localization:

  • 'center_of_mass': the classic, fast localization method, used for instance by herdingspikes. It is quite accurate on square MEAs but produces strong artifacts when units sit at the border of the probe, so on linear probes it gives poor results along the x axis.
  • 'monopolar_triangulation' with optimizer='least_square': this method is from Julien Boussard and Erdem Varol from the Paninski lab. It was presented at NeurIPS; see also here.
  • 'monopolar_triangulation' with optimizer='minimize_with_log_penality': an improvement on the previous method by the same team, not published yet.
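
To make the 'center_of_mass' idea concrete, here is a minimal numpy sketch. It is not the spikeinterface implementation: the helper name, the toy geometry, and the amplitudes are all made up for illustration, and we simply assume that the location is the amplitude-weighted average of the contact positions.

```python
import numpy as np

def center_of_mass(channel_positions, amplitudes):
    """Amplitude-weighted average of channel positions (hypothetical helper)."""
    weights = np.abs(amplitudes)
    return (weights[:, None] * channel_positions).sum(axis=0) / weights.sum()

# Four contacts on a 2 x 2 grid with 20 um pitch (made-up geometry).
channel_positions = np.array([[0.0, 0.0], [20.0, 0.0], [0.0, 20.0], [20.0, 20.0]])
# Peak amplitude of one spike on each contact (made-up values).
amplitudes = np.array([-10.0, -30.0, -10.0, -30.0])

x, y = center_of_mass(channel_positions, amplitudes)
print(x, y)  # 15.0 10.0
```

Because the result is a weighted average with non-negative weights, it can never fall outside the area spanned by the contacts. This is one way to see why the estimate is biased for units at the border of the probe, and why a single-column probe carries little information along x.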

Here is an example of how to use them.

In [1]:
%load_ext autoreload
%autoreload 2
In [4]:
from pathlib import Path

import spikeinterface.full as si

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20, 12)
In [5]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP1'
preprocess_folder = base_folder / 'dataset1_NP1_preprocessed'
peak_folder = base_folder / 'dataset1_NP1_peaks'

peak_folder.mkdir(exist_ok=True)
In [6]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_duration='1s',
    progress_bar=True,
)
In [11]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[11]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
In [12]:
fig, ax = plt.subplots()
si.plot_probe_map(rec, ax=ax)
ax.set_ylim(-150, 200)
Out[12]:
(-150.0, 200.0)

preprocess

This takes about 4 min for 30 min of signal.

In [13]:
if not preprocess_folder.exists():
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
    rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
    rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
rec_preprocessed = si.load_extractor(preprocess_folder)
write_binary_recording with n_jobs 40  chunk_size 30000
write_binary_recording: 100%|██████████| 1958/1958 [03:09<00:00, 10.34it/s]
In [14]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[14]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f894ea61af0>

estimate noise

In [15]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=np.arange(0,10, 1))
ax.set_title('noise across channel')
Out[15]:
Text(0.5, 1.0, 'noise across channel')
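
As a rough illustration of what a per-channel noise level can look like, here is a minimal sketch based on the median absolute deviation (MAD). That the library estimates noise this way is an assumption on our part; the helper name and the toy traces are hypothetical.

```python
import numpy as np

def mad_noise_levels(traces):
    """Per-channel noise estimate from the median absolute deviation.

    traces has shape (num_samples, num_channels). Dividing by 0.6744897501960817
    rescales the MAD so that it matches the standard deviation of Gaussian noise.
    """
    med = np.median(traces, axis=0, keepdims=True)
    mad = np.median(np.abs(traces - med), axis=0)
    return mad / 0.6744897501960817

rng = np.random.default_rng(0)
# Two toy channels with Gaussian noise of standard deviation 2 and 5.
traces = rng.normal(0.0, [2.0, 5.0], size=(50_000, 2))
# Add a few large negative "spikes"; the median-based estimate barely moves.
traces[::1000, :] -= 100.0

noise = mad_noise_levels(traces)
print(noise)
```

The point of using a median-based estimate is robustness: large spikes inflate a plain standard deviation but have little effect on the MAD.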

detect peaks

This takes about 1 min 10 s.

In [16]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [18]:
if not (peak_folder / 'peaks.npy').exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        local_radius_um=100,
        peak_sign='neg',
        detect_threshold=5,
        n_shifts=5,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peak_folder / 'peaks.npy', peaks)
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
detect peaks: 100%|██████████| 1958/1958 [01:11<00:00, 27.26it/s]
(4041179,)
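
The parameters above (peak_sign, detect_threshold, n_shifts, noise_levels) can be read through a simplified single-channel sketch. This is not the 'locally_exclusive' implementation, which presumably also compares neighbouring channels within local_radius_um; the function below is a hypothetical helper that only shows thresholding in units of noise plus a local-minimum test over n_shifts samples on each side.

```python
import numpy as np

def detect_negative_peaks(trace, noise_level, detect_threshold=5, n_shifts=2):
    """Indices that cross -detect_threshold * noise_level and are the strict
    minimum within n_shifts samples on each side (hypothetical helper)."""
    threshold = -detect_threshold * noise_level
    peaks = []
    for i in range(n_shifts, trace.size - n_shifts):
        if trace[i] >= threshold:
            continue
        window = trace[i - n_shifts:i + n_shifts + 1]
        is_min = trace[i] == window.min()
        is_unique = np.count_nonzero(window == trace[i]) == 1
        if is_min and is_unique:
            peaks.append(i)
    return np.array(peaks, dtype=int)

trace = np.zeros(100)
trace[20:23] = [-4.0, -9.0, -6.0]  # one spike with its trough at sample 21
trace[60] = -3.0                   # too small, stays above the threshold
trace[80] = -7.0                   # isolated spike

peak_inds = detect_negative_peaks(trace, noise_level=1.0)
print(peak_inds)  # [21 80]
```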

localize peaks

We compare three variants:

  • 'center_of_mass': 9 s
  • 'monopolar_triangulation', legacy ('least_square' optimizer): 26 min
  • 'monopolar_triangulation', log barrier ('minimize_with_log_penality' optimizer): 16 min
In [19]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
In [21]:
if not (peak_folder / 'peak_locations_center_of_mass.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='center_of_mass',
        method_kwargs={'local_radius_um': 100.},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_center_of_mass.npy', peak_locations)
    print(peak_locations.shape)
localize peaks: 100%|██████████| 1958/1958 [00:08<00:00, 218.72it/s]
(4041179,)
In [36]:
if not (peak_folder / 'peak_locations_monopolar_triangulation_legacy.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'least_square'},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_monopolar_triangulation_legacy.npy', peak_locations)
    print(peak_locations.shape)
localize peaks: 100%|██████████| 1958/1958 [26:42<00:00,  1.22it/s] 
(4041179,)
In [23]:
if not (peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'minimize_with_log_penality'},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy', peak_locations)
    print(peak_locations.shape)
localize peaks: 100%|██████████| 1958/1958 [16:15<00:00,  2.01it/s]
(4041179,)
In [24]:
# load back
# peak_locations = np.load(peak_folder / 'peak_locations_center_of_mass.npy')
# peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation_legacy.npy')
peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy')
print(peak_locations)
[( 18.81849235, 1782.84538913,  78.17532357, 1696.96239445)
 ( 31.90279769, 3847.75061369, 134.79844077, 1716.03155721)
 (-23.12038001, 2632.87834759,  87.76916268, 2633.62546695) ...
 ( 40.0839554 , 1977.83852796,  26.50998809, 1092.53885299)
 (-51.40036701, 1772.34521905, 170.65660676, 2533.03617278)
 ( 54.3813594 , 1182.28971165,  87.35020554, 1303.53392431)]
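
The printed output is a numpy structured array with one record per peak. The sketch below shows how such an array can be sliced by field. The fields 'x', 'y' and 'z' are the ones used by the plotting cells; naming the fourth field 'alpha' is an assumption, and the values are rounded toy numbers.

```python
import numpy as np

# 'x', 'y', 'z' are used by the plotting cells; 'alpha' is an assumed name.
loc_dtype = [('x', 'float64'), ('y', 'float64'), ('z', 'float64'), ('alpha', 'float64')]
toy_locations = np.array(
    [(18.8, 1782.8, 78.2, 1697.0),
     (31.9, 3847.8, 134.8, 1716.0),
     (-23.1, 2632.9, 87.8, 2633.6)],
    dtype=loc_dtype,
)

# Boolean mask on one field, e.g. keep peaks in a depth range of interest.
in_range = (toy_locations['y'] > 1500) & (toy_locations['y'] < 2500)
selected = toy_locations[in_range]

# Plain (n_peaks, 2) float array, convenient for scatter plots or clustering.
xy = np.column_stack([toy_locations['x'], toy_locations['y']])
print(selected.size, xy.shape)  # 1 (3, 2)
```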

plot on probe

In [38]:
from probeinterface.plotting import plot_probe

for name in ('center_of_mass', 'monopolar_triangulation_legacy', 'monopolar_triangulation_log_limit'):

    peak_locations = np.load(peak_folder / f'peak_locations_{name}.npy')

    probe = rec_preprocessed.get_probe()

    fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
    ax = axs[0]
    plot_probe(probe, ax=ax)
    ax.scatter(peak_locations['x'], peak_locations['y'], color='k', s=1, alpha=0.002)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title(name)
    if 'z' in peak_locations.dtype.fields:
        ax = axs[1]
        ax.scatter(peak_locations['z'], peak_locations['y'], color='k', s=1, alpha=0.002)
        ax.set_xlabel('z')
    ax.set_ylim(1500, 2500)

plot peak depth vs time

In [39]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)
Out[39]:
(1300.0, 2500.0)

conclusion

spikeinterface clustering

spikeinterface clustering

The clustering step remains the central step of spike sorting. Historically this step was separated into two distinct parts: feature reduction and clustering. In spikeinterface, we decided to group these two steps in the same module. This allows one to compute feature reduction on-the-fly and avoid long computations and storage of large features.

The clustering step takes the recording and detected (and optionally selected) peaks as input and returns a label for every peak.

At the moment, the implementation is quite experimental. These methods have been implemented:

  • "position_clustering": use HDBSCAN on peak locations.
  • "sliding_hdbscan": clustering approach from tridesclous, with sliding spatial windows. PCA and HDBSCAN are run on local/sparse waveforms.
  • "position_pca_clustering": this method tries to use peak locations for a first clustering step and then perform further splits using PCA + HDBSCAN

Different methods may need different inputs (for instance, some of them need peak locations and some do not).

For this we will use a dataset simulated with MEArec on a 32-channel Neuronexus-like probe.

Here we will also use the select_peaks() function to subsample a smaller number of peaks.

In [1]:
# %matplotlib widget
%matplotlib inline
In [2]:
%load_ext autoreload
%autoreload 2
In [3]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [18]:
base_folder = Path('/mnt/data/sam/DataSpikeSorting/mearec_template_matching')
mearec_file = base_folder / 'recordings_collision_15cells_Neuronexus-32_1800s.h5'
rec_folder = base_folder /'Preprocessed_recording_15cells_Neuronexus-32_1800s'
peak_folder = base_folder / 'Peak_recording_15cells_Neuronexus'

clustering_path = base_folder / 'Clustering_recording_15cells_Neuronexus'

peak_folder.mkdir(exist_ok=True)
In [5]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_duration='1s',
    progress_bar=True,
)

Preprocess

In [6]:
# load the cached recording if it exists, otherwise compute and save it
if rec_folder.exists():
    rec_preprocessed = si.load_extractor(rec_folder)
else:
    recording, gt_sorting = si.read_mearec(mearec_file)
    recording = si.bandpass_filter(recording, dtype='float32')
    recording = si.common_reference(recording)
    rec_preprocessed = recording.save(folder=rec_folder, n_jobs=20, chunk_size=30000, progress_bar=True)

estimate noise

In [7]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=15)
ax.set_title('noise across channel')
Out[7]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

In [8]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [9]:
if not (peak_folder / 'peaks.npy').exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        local_radius_um=100,
        peak_sign='neg',
        detect_threshold=5,
        n_shifts=10,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peak_folder / 'peaks.npy', peaks)
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
(271519,)

select peaks

In [10]:
from spikeinterface.sortingcomponents.peak_selection import select_peaks
In [11]:
if not (peak_folder / 'some_peaks.npy').exists():
    some_peaks = select_peaks(peaks, method='uniform', select_per_channel=True, n_peaks=500, seed=None)
    np.save(peak_folder / 'some_peaks.npy', some_peaks)
some_peaks = np.load(peak_folder / 'some_peaks.npy')
print('some_peaks.size', some_peaks.size)
some_peaks.size 13424
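
A minimal sketch of what uniform per-channel selection could look like is given below. It is not the library code: the helper is hypothetical, and we assume the peaks array has a 'channel_ind' field next to the 'sample_ind' field used elsewhere in these notebooks. It also suggests one plausible reason why the result above is smaller than 32 x 500 = 16000: channels with fewer than n_peaks peaks can only contribute what they have.

```python
import numpy as np

def select_uniform_per_channel(peaks, n_peaks, seed=None):
    """Randomly keep at most n_peaks peaks on every channel (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    keep = []
    for chan in np.unique(peaks['channel_ind']):
        inds = np.flatnonzero(peaks['channel_ind'] == chan)
        n = min(n_peaks, inds.size)
        keep.append(rng.choice(inds, size=n, replace=False))
    keep = np.sort(np.concatenate(keep))  # preserve the original time order
    return peaks[keep]

# Toy peaks: 1000 on channel 0, 200 on channel 1, 100 on channel 2.
peak_dtype = [('sample_ind', 'int64'), ('channel_ind', 'int64')]
toy_peaks = np.zeros(1300, dtype=peak_dtype)
toy_peaks['sample_ind'] = np.arange(1300) * 30
toy_peaks['channel_ind'] = np.repeat([0, 1, 2], [1000, 200, 100])

subset = select_uniform_per_channel(toy_peaks, n_peaks=500, seed=0)
print(subset.size)  # 800 = 500 + 200 + 100
```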

localize peaks (on sub selection)

In [12]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
In [14]:
some_peak_locations = localize_peaks(rec_preprocessed, some_peaks,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'minimize_with_log_penality'},
        # method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'least_square'},
        **job_kwargs)
np.save(peak_folder / 'some_peak_locations.npy', some_peak_locations)
localize peaks: 100%|██████████| 1800/1800 [00:03<00:00, 513.78it/s]
In [15]:
some_peak_locations = np.load(peak_folder / f'some_peak_locations.npy')

probe = rec_preprocessed.get_probe()

fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
ax = axs[0]
si.plot_probe_map(rec_preprocessed, ax=ax)
ax.scatter(some_peak_locations['x'], some_peak_locations['y'], color='k', s=1, alpha=0.2)
ax.set_xlabel('x')
ax.set_ylabel('y')
if 'z' in some_peak_locations.dtype.fields:
    ax = axs[1]
    ax.scatter(some_peak_locations['z'], some_peak_locations['y'], color='k', s=1, alpha=0.2)
    ax.set_xlabel('z')
# ax.set_ylim(1500, 2500)

clustering (on sub selection)

In [19]:
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks
In [43]:
method_kwargs = dict(
    peak_locations=some_peak_locations,
    hdbscan_params_spatial = {"min_cluster_size" : 20,  "allow_single_cluster" : True, 'metric' : 'l2'},
    probability_thr = 0,
    apply_norm=True,
    #~ debug=True,
    debug=False,
    tmp_folder=clustering_path,
    n_components_by_channel=4,
    n_components=4,
    job_kwargs = {"n_jobs" : 2, "chunk_size" : 30000, "progress_bar" : True},
    waveform_mode="shared_memory",
    #~ waveform_mode="memmap",
)

t0 = time.perf_counter()
possible_labels, peak_labels = find_cluster_from_peaks(rec_preprocessed, some_peaks, 
        method='position_pca_clustering', method_kwargs=method_kwargs)
t1 = time.perf_counter()
print('position_pca_clustering', t1 -t0)
extract waveforms shared_memory: 100%|██████████| 1800/1800 [00:00<00:00, 4997.14it/s]
extract waveforms shared_memory: 100%|██████████| 1800/1800 [00:00<00:00, 5033.09it/s]
position_pca_clustering 12.23708628397435
In [44]:
print(possible_labels)
[ 3  4  6 10 11 14 16 18 19 20 22 28 29 32 33]
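
The two returned arrays play different roles: possible_labels lists the cluster labels that exist, while peak_labels holds one label per input peak. A small sketch with toy arrays shows how to summarize them. Treating negative labels as "noise or unassigned" is an assumption here.

```python
import numpy as np

# Toy stand-ins for the two arrays returned by the clustering step.
possible_labels = np.array([3, 4, 6])
peak_labels = np.array([3, 3, -1, 4, 6, 6, 6, -1, 4, 3])

# Count peaks per cluster, ignoring negative (assumed noise) labels.
labels, counts = np.unique(peak_labels[peak_labels >= 0], return_counts=True)
peaks_per_cluster = dict(zip(labels.tolist(), counts.tolist()))
n_unassigned = int(np.sum(peak_labels < 0))
print(peaks_per_cluster, n_unassigned)  # {3: 3, 4: 2, 6: 3} 2
```
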
In [45]:
import distinctipy
def plot_cluster_on_probe(rec, possible_labels, peak_labels):
    possible_colors = distinctipy.get_colors(possible_labels.size)

    colors = np.zeros((peak_labels.size, 3))
    for i, k in enumerate(possible_labels):
        mask = peak_labels == k
        colors[mask, :] = possible_colors[i]

    fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
    ax = axs[0]
    si.plot_probe_map(rec, ax=ax)
    ax.scatter(some_peak_locations['x'], some_peak_locations['y'], s=1, c=colors, alpha=0.5)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    if 'z' in some_peak_locations.dtype.fields:
        ax = axs[1]
        ax.scatter(some_peak_locations['z'], some_peak_locations['y'], s=1, c=colors, alpha=0.5)
        ax.set_xlabel('z')
In [46]:
plot_cluster_on_probe(rec_preprocessed, possible_labels, peak_labels)

Let's try another method

In [47]:
method_kwargs = dict(
)

t0 = time.perf_counter()
possible_labels, peak_labels = find_cluster_from_peaks(rec_preprocessed, some_peaks, 
        method='sliding_hdbscan', method_kwargs=method_kwargs)
t1 = time.perf_counter()
print('sliding_hdbscan', t1 - t0)
sliding_hdbscan 19.997920085676014
In [48]:
print(possible_labels)
[ 1  2  3  4  5  6  7  8  9 10 11 16 20]
In [49]:
plot_cluster_on_probe(rec_preprocessed, possible_labels, peak_labels)

spikeinterface motion estimation / correction

motion estimation in spikeinterface

In 2021, the SpikeInterface project started to implement sortingcomponents, a modular set of spike sorting steps.

Here is an overview of our progress integrating motion (aka drift) estimation and correction.

This notebook is based on the open "Imposed motion datasets" from Steinmetz et al., Science 2021, published by Nick Steinmetz: https://figshare.com/articles/dataset/_Imposed_motion_datasets_from_Steinmetz_et_al_Science_2021/14024495

The motion estimation is done in several modular steps:

  1. detect peaks
  2. localize peaks
  3. estimate motion:
    • rigid or non-rigid
    • "decentralized" by Erdem Varol and Julien Boussard, DOI: 10.1109/ICASSP39728.2021.9414145
    • "motion cloud" by Julien Boussard (not implemented yet)

Here we will show this chain:

  • detect peaks > localize peaks with "monopolar_triangulation" > estimate motion with "decentralized"
In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
from pathlib import Path

import spikeinterface.full as si

import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20, 12)

from probeinterface.plotting import plot_probe
In [3]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP1'
preprocess_folder = base_folder / 'dataset1_NP1_preprocessed'
peak_folder = base_folder / 'dataset1_NP1_peaks'

peak_folder.mkdir(exist_ok=True)
In [4]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_memory='10M',
    progress_bar=True,
)
In [5]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[5]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1957.191s
In [7]:
fig, ax = plt.subplots()
si.plot_probe_map(rec, ax=ax)
ax.set_ylim(-150, 200)
Out[7]:
(-150.0, 200.0)

preprocess

This takes about 4 min for 30 min of signal.

In [8]:
if not preprocess_folder.exists():
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
    rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
    rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
rec_preprocessed = si.load_extractor(preprocess_folder)
In [9]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[9]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f72adda0520>

estimate noise

In [13]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=10)
ax.set_title('noise across channel')
Out[13]:
Text(0.5, 1.0, 'noise across channel')

detect peaks

This takes about 1 min 30 s.

In [11]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [14]:
if not (peak_folder / 'peaks.npy').exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        local_radius_um=100,
        peak_sign='neg',
        detect_threshold=5,
        n_shifts=5,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peak_folder / 'peaks.npy', peaks)
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
(4041179,)

localize peaks

Here we choose 'monopolar_triangulation' with the log barrier optimizer.

In [18]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
In [16]:
if not (peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy').exists():
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        ms_before=0.3,
        ms_after=0.6,
        method='monopolar_triangulation',
        method_kwargs={'local_radius_um': 100., 'max_distance_um': 1000., 'optimizer': 'minimize_with_log_penality'},
        **job_kwargs,
    )
    np.save(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy', peak_locations)
    print(peak_locations.shape)
peak_locations = np.load(peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy')
In [17]:
print(peak_locations)
[( 18.81849235, 1782.84538913,  78.17532357, 1696.96239445)
 ( 31.90279769, 3847.75061369, 134.79844077, 1716.03155721)
 (-23.12038001, 2632.87834759,  87.76916268, 2633.62546695) ...
 ( 40.0839554 , 1977.83852796,  26.50998809, 1092.53885299)
 (-51.40036701, 1772.34521905, 170.65660676, 2533.03617278)
 ( 54.3813594 , 1182.28971165,  87.35020554, 1303.53392431)]

plot on probe

In [22]:
fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
ax = axs[0]
si.plot_probe_map(rec_preprocessed, ax=ax)
ax.scatter(peak_locations['x'], peak_locations['y'], color='k', s=1, alpha=0.002)
ax.set_xlabel('x')
ax.set_ylabel('y')
if 'z' in peak_locations.dtype.fields:
    ax = axs[1]
    ax.scatter(peak_locations['z'], peak_locations['y'], color='k', s=1, alpha=0.002)
    ax.set_xlabel('z')
    ax.set_xlim(0, 150)
ax.set_ylim(1800, 2500)
Out[22]:
(1800.0, 2500.0)

plot peak depth vs time

In [23]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)
Out[23]:
(1300.0, 2500.0)

motion estimate : rigid with decentralized

In [25]:
from spikeinterface.sortingcomponents.motion_estimation import (
    estimate_motion,
    make_motion_histogram,
    compute_pairwise_displacement,
    compute_global_displacement
)
In [45]:
bin_um = 5
bin_duration_s=5.

motion_histogram, temporal_bins, spatial_bins = make_motion_histogram(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations, 
    bin_um=bin_um,
    bin_duration_s=bin_duration_s,
    direction='y',
    weight_with_amplitude=False,
)
print(motion_histogram.shape, temporal_bins.size, spatial_bins.size)
(392, 784) 393 785
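
The printed shapes follow the usual histogram convention: the bin arrays hold edges, so each has one more element than the matching histogram axis. The sketch below reproduces that with numpy.histogram2d on toy data. It is not the make_motion_histogram() implementation; the helper and the simulated drifting unit are hypothetical.

```python
import numpy as np

def toy_motion_histogram(peak_times_s, peak_depths_um, duration_s, depth_max_um,
                         bin_duration_s=5.0, bin_um=5.0):
    """Peak counts with time bins along axis 0 and depth bins along axis 1."""
    temporal_bins = np.arange(0.0, duration_s + bin_duration_s, bin_duration_s)
    spatial_bins = np.arange(0.0, depth_max_um + bin_um, bin_um)
    hist, _, _ = np.histogram2d(peak_times_s, peak_depths_um,
                                bins=(temporal_bins, spatial_bins))
    return hist, temporal_bins, spatial_bins

rng = np.random.default_rng(0)
n_peaks = 1000
peak_times_s = rng.uniform(0.0, 100.0, size=n_peaks)
# One toy "unit" at 100 um depth that drifts upward by 0.2 um per second.
peak_depths_um = 100.0 + 0.2 * peak_times_s + rng.normal(0.0, 3.0, size=n_peaks)

hist, temporal_bins, spatial_bins = toy_motion_histogram(
    peak_times_s, peak_depths_um, duration_s=100.0, depth_max_um=200.0)
print(hist.shape, temporal_bins.size, spatial_bins.size)  # (20, 40) 21 41
```
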
In [32]:
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1])
im = ax.imshow(
    motion_histogram.T,
    interpolation='nearest',
    origin='lower',
    aspect='auto',
    extent=extent,
)
im.set_clim(0, 30)
ax.set_ylim(1300, 2500)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')
Out[32]:
Text(0, 0.5, 'depth[um]')

pairwise displacement from the motion histogram

In [39]:
pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement(motion_histogram, bin_um, method='conv2d', )
np.save(peak_folder / 'pairwise_displacement_conv2d.npy', pairwise_displacement)
In [40]:
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], temporal_bins[0], temporal_bins[-1])
# extent = None
im = ax.imshow(
    pairwise_displacement,
    interpolation='nearest',
    cmap='PiYG',
    origin='lower',
    aspect='auto',
    extent=extent,
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)
Out[40]:
<matplotlib.colorbar.Colorbar at 0x7f7238013400>
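
To give an intuition for this matrix, here is a minimal sketch that estimates the displacement between every pair of time bins by cross-correlating their depth profiles. It is only an illustration under our own assumptions: the helper is hypothetical, the sign convention (row minus column) is chosen for the example, and the library's 'conv2d' method may differ in its details.

```python
import numpy as np

def displacement_between(hist_a, hist_b, bin_um):
    """Depth shift (um) of profile a relative to profile b, taken from the
    lag that maximizes their cross-correlation (hypothetical helper)."""
    corr = np.correlate(hist_a, hist_b, mode='full')
    lag = np.argmax(corr) - (hist_b.size - 1)
    return lag * bin_um

# Toy depth profiles: the same Gaussian bump, shifted by a known number of bins.
depth = np.arange(60)
shifts_bins = [0, 2, 4, 1]
hists = np.array([np.exp(-0.5 * ((depth - 25 - s) / 2.0) ** 2) for s in shifts_bins])

bin_um = 5.0
n = len(hists)
pairwise = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        pairwise[i, j] = displacement_between(hists[i], hists[j], bin_um)

print(pairwise[1, 0], pairwise[0, 1])  # 10.0 -10.0
```

With this convention the matrix is antisymmetric and its diagonal is zero, which matches the diverging colormap centered on zero used in the plot above.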

estimate motion (rigid) from the pairwise displacement

In [43]:
motion = compute_global_displacement(pairwise_displacement)

motion = compute_global_displacement(pairwise_displacement,convergence_method='gradient_descent',)
# motion = compute_global_displacement(pairwise_displacement, pairwise_displacement_weight=pairwise_displacement_weight, convergence_method='lsqr_robust',)
In [47]:
fig, ax = plt.subplots()
ax.plot(temporal_bins[:-1], motion)
Out[47]:
[<matplotlib.lines.Line2D at 0x7f7238624a60>]
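
Going from a pairwise matrix to a single motion trace can be sketched in a few lines, assuming the model D[i, j] ≈ p[i] - p[j], where p is the probe position in each time bin. Under that assumption, and for an antisymmetric D, the least-squares solution with zero mean is simply the row mean of D. This is a simplified stand-in, not what compute_global_displacement() does internally.

```python
import numpy as np

# Toy ground-truth position of the probe in each time bin (um).
true_motion = np.array([0.0, 10.0, 20.0, 5.0, -5.0])

# Ideal pairwise displacement under the assumed model D[i, j] = p[i] - p[j].
pairwise = true_motion[:, None] - true_motion[None, :]

# Row mean: mean_j (p[i] - p[j]) = p[i] - mean(p).
motion_est = pairwise.mean(axis=1)
print(motion_est)  # [ -6.   4.  14.  -1. -11.]
```

The motion is only recovered up to a constant offset, since pairwise differences carry no information about the absolute position. Iterative or robust solvers become useful when some entries of the matrix are unreliable and need to be down-weighted.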

motion estimation with a single function

Internally estimate_motion() does:

  • make_motion_histogram()
  • compute_pairwise_displacement()
  • compute_global_displacement()
In [58]:
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.widgets import plot_pairwise_displacement, plot_displacement
In [59]:
method='decentralized_registration'
method_kwargs = dict(
    pairwise_displacement_method='conv2d',
    # convergence_method='gradient_descent',
    convergence_method='lsqr_robust',
)

# method='decentralized_registration'
# method_kwargs = dict(
#     pairwise_displacement_method='phase_cross_correlation',
#     convergence_method='lsqr_robust',
# )


motion, temporal_bins, spatial_bins, extra_check = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=10.,
    method=method,
    method_kwargs=method_kwargs,
    non_rigid_kwargs=None,
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
)
100%|██████████| 392/392 [00:09<00:00, 40.71it/s]
In [60]:
plot_pairwise_displacement(motion, temporal_bins, spatial_bins, extra_check, ncols=4)
Out[60]:
<spikeinterface.widgets.drift.PairwiseDisplacementWidget at 0x7f72427cf4c0>
In [61]:
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=True)
Out[61]:
<spikeinterface.widgets.drift.DisplacementWidget at 0x7f72384dd8b0>
In [62]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=False, ax=ax)
Out[62]:
<spikeinterface.widgets.drift.DisplacementWidget at 0x7f72381deca0>

motion estimation non rigid

In [64]:
# method='decentralized_registration'
# method_kwargs = dict(
#     pairwise_displacement_method='conv2d',
#     convergence_method='gradient_descent',
# )

method='decentralized_registration'
method_kwargs = dict(
    pairwise_displacement_method='conv2d',
    convergence_method='lsqr_robust',
)


# method='decentralized_registration'
# method_kwargs = dict(
#     pairwise_displacement_method='phase_cross_correlation',
#     convergence_method='lsqr_robust',
# )


motion, temporal_bins, spatial_bins, extra_check = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=5.,
    method=method,
    method_kwargs=method_kwargs,
    non_rigid_kwargs=dict(bin_step_um=200, signam=3),
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
)
100%|██████████| 392/392 [00:18<00:00, 21.18it/s]
100%|██████████| 392/392 [00:20<00:00, 19.59it/s]
100%|██████████| 392/392 [00:18<00:00, 21.08it/s]
100%|██████████| 392/392 [00:18<00:00, 21.04it/s]
100%|██████████| 392/392 [00:20<00:00, 19.49it/s]
100%|██████████| 392/392 [00:18<00:00, 21.12it/s]
100%|██████████| 392/392 [00:20<00:00, 19.47it/s]
100%|██████████| 392/392 [00:18<00:00, 21.08it/s]
100%|██████████| 392/392 [00:20<00:00, 19.57it/s]
100%|██████████| 392/392 [00:18<00:00, 21.12it/s]
100%|██████████| 392/392 [00:20<00:00, 19.59it/s]
100%|██████████| 392/392 [00:18<00:00, 21.12it/s]
100%|██████████| 392/392 [00:20<00:00, 19.59it/s]
100%|██████████| 392/392 [00:18<00:00, 21.07it/s]
100%|██████████| 392/392 [00:20<00:00, 19.54it/s]
100%|██████████| 392/392 [00:18<00:00, 21.04it/s]
100%|██████████| 392/392 [00:19<00:00, 19.60it/s]
100%|██████████| 392/392 [00:18<00:00, 21.09it/s]
100%|██████████| 392/392 [00:19<00:00, 19.61it/s]
100%|██████████| 392/392 [00:18<00:00, 21.25it/s]
In [65]:
fig, ax = plt.subplots()
for win in extra_check['non_rigid_windows']:
    ax.plot(win, extra_check['spatial_hist_bins'][:-1])
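
One way to picture the non-rigid windows plotted here is as overlapping tapers along the depth axis, each used to weight the motion histogram before a rigid estimate is computed for that window. The sketch below assumes Gaussian tapers. The variable names, the window spacing and the width sigma_um are hypothetical and are not the library's parameters.

```python
import numpy as np

bin_um = 5.0
depth_edges = np.arange(0.0, 1000.0 + bin_um, bin_um)  # 201 edges, 200 bins
depth_centers = depth_edges[:-1] + bin_um / 2

# One window every 200 um (hypothetical spacing and width).
window_centers = np.arange(102.5, 1000.0, 200.0)
sigma_um = 100.0
windows = np.exp(
    -0.5 * ((depth_centers[None, :] - window_centers[:, None]) / sigma_um) ** 2
)

# Weighting a (time, depth) histogram by one window keeps mostly the peaks
# near that window's center.
toy_histogram = np.ones((10, depth_centers.size))
local_histogram = toy_histogram * windows[0][None, :]
print(windows.shape, local_histogram.shape)  # (5, 200) (10, 200)
```
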
In [66]:
plot_pairwise_displacement(motion, temporal_bins, spatial_bins, extra_check, ncols=4)
Out[66]:
<spikeinterface.widgets.drift.PairwiseDisplacementWidget at 0x7f722db68dc0>
In [67]:
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=True)
Out[67]:
<spikeinterface.widgets.drift.DisplacementWidget at 0x7f722b6d6d60>
In [69]:
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=False, ax=ax)
ax.set_ylim(0, 2000)
Out[69]:
(0.0, 2000.0)
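
The scatter of peak time versus peak depth above is the raw material for motion estimation. As a rough illustration (not the estimate_motion implementation), the peaks can be binned in time and depth with the same bin_duration_s and bin_um values as in the call above. The peak times and depths below are synthetic, and every variable name is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for peaks['sample_ind'] / fs and peak_locations['y'] (assumed data).
duration_s = 100.0
n_peaks = 5000
peak_times_s = rng.uniform(0.0, duration_s, n_peaks)
# One "unit" at 500 um that drifts linearly by 20 um over the recording.
peak_depths_um = 500.0 + 20.0 * peak_times_s / duration_s + rng.normal(0.0, 3.0, n_peaks)

bin_duration_s = 5.0  # same values as in the estimate_motion call above
bin_um = 5.0
time_edges = np.arange(0.0, duration_s + bin_duration_s, bin_duration_s)
depth_edges = np.arange(400.0, 600.0 + bin_um, bin_um)

# Rows are time bins, columns are depth bins.
activity, _, _ = np.histogram2d(peak_times_s, peak_depths_um, bins=(time_edges, depth_edges))

# Depth of the most active bin in the first and in the last time bin.
depth_centers = depth_edges[:-1] + bin_um / 2
first_depth = depth_centers[np.argmax(activity[0])]
last_depth = depth_centers[np.argmax(activity[-1])]
print(activity.shape, first_depth, last_depth)
```

The shift of the most active depth bin between the first and last time bins is the kind of displacement that the motion estimate tracks.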
In [70]:
fig, ax = plt.subplots()
ax.plot(temporal_bins, motion)
Out[70]:
[<matplotlib.lines.Line2D at 0x7f722b466790>,
 <matplotlib.lines.Line2D at 0x7f722b4667c0>,
 <matplotlib.lines.Line2D at 0x7f722b4668e0>,
 <matplotlib.lines.Line2D at 0x7f722b466a00>,
 <matplotlib.lines.Line2D at 0x7f722b4577f0>,
 <matplotlib.lines.Line2D at 0x7f722b457820>,
 <matplotlib.lines.Line2D at 0x7f722b466c40>,
 <matplotlib.lines.Line2D at 0x7f722b466d60>,
 <matplotlib.lines.Line2D at 0x7f722b466e80>,
 <matplotlib.lines.Line2D at 0x7f722b466fa0>,
 <matplotlib.lines.Line2D at 0x7f722b4412e0>,
 <matplotlib.lines.Line2D at 0x7f722b3f0100>,
 <matplotlib.lines.Line2D at 0x7f722b3f0310>,
 <matplotlib.lines.Line2D at 0x7f722b3f0430>,
 <matplotlib.lines.Line2D at 0x7f722b3f0550>,
 <matplotlib.lines.Line2D at 0x7f722b3f0670>,
 <matplotlib.lines.Line2D at 0x7f722b3f0790>,
 <matplotlib.lines.Line2D at 0x7f722b3f08b0>,
 <matplotlib.lines.Line2D at 0x7f722b3f09d0>,
 <matplotlib.lines.Line2D at 0x7f722b3f0af0>]
In [71]:
fig, ax = plt.subplots()
im = ax.imshow(motion.T,
    interpolation='nearest',
    cmap='PiYG',
    origin='lower',
    aspect='auto',
    # extent=extent,
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)
Out[71]:
<matplotlib.colorbar.Colorbar at 0x7f722b336a30>

spikeinterface peak detection

peak detection in spikeinterface

Author : Samuel Garcia

spikeinterface implements several methods for peak detection.

Peak detection can be used:

  1. as the first step of a spike sorting chain
  2. as the first step for estimating motion (aka drift)

Here we illustrate how this works and how, in conjunction with the preprocessing module, the detection can be computed with or without caching the preprocessed traces to disk.

This example is run on Neuropixels 1 and Neuropixels 2 recordings made by Nick Steinmetz, available here.

In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
# %matplotlib widget
%matplotlib inline
In [4]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import spikeinterface.full as si

open dataset

In [4]:
# local folder
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')

dataset_folder = base_folder / 'dataset1/NP2'
preprocess_folder_bin = base_folder / 'dataset1_NP2_preprocessed_binary'
preprocess_folder_zarr = base_folder / 'dataset1_NP2_preprocessed_zarr'
In [5]:
# read the file
rec = si.read_spikeglx(dataset_folder)
rec
Out[5]:
SpikeGLXRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s
In [6]:
fig, ax = plt.subplots(figsize=(7, 20))
si.plot_probe_map(rec, with_channel_ids=True, ax=ax)
ax.set_ylim(-150, 200)
Out[6]:
(-150.0, 200.0)

preprocess

Here we apply filtering + CMR (common median reference).

To demonstrate the flexibility, we will work on 3 objects:

  • the lazy object rec_preprocessed
  • the cached object in binary format rec_preprocessed_cached_binary
  • the cached object in zarr format rec_preprocessed_cached_zarr

Caching to binary takes about 3:50 and caching to zarr takes about 3:36 here (see the progress bars of the save() calls).
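
To make the difference between "lazy" and "cached" concrete, here is a toy sketch in plain numpy. The LazyStep class is hypothetical and is not the spikeinterface class: it only shows that a lazy object stores its parent and a function, and computes a chunk only when get_traces() is called, while caching means computing once and keeping the result.

```python
import numpy as np

class LazyStep:
    """Hypothetical stand-in for a lazy preprocessing object (not the spikeinterface class)."""

    def __init__(self, parent, func):
        self.parent = parent  # another LazyStep, or a raw numpy array
        self.func = func
        self.n_calls = 0

    def get_traces(self, start, end):
        # Nothing is computed until a chunk is requested.
        if isinstance(self.parent, LazyStep):
            chunk = self.parent.get_traces(start, end)
        else:
            chunk = self.parent[start:end]
        self.n_calls += 1
        return self.func(chunk)

raw = np.arange(20, dtype="float32").reshape(10, 2)  # (samples, channels)
centered = LazyStep(raw, lambda x: x - x.mean(axis=1, keepdims=True))
scaled = LazyStep(centered, lambda x: 2.0 * x)

chunk = scaled.get_traces(0, 5)            # the whole chain runs on demand, at every request
cached = scaled.get_traces(0, 10).copy()   # "caching": run the chain once and keep the result
```

Reading from `cached` afterwards costs no computation, which is the trade-off explored in the rest of this notebook.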

In [7]:
# global kwargs for parallel computing
job_kwargs = dict(
    n_jobs=20,
    chunk_duration='1s',
    progress_bar=True,
)
In [8]:
# create the lazy object
rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
rec_preprocessed = si.common_reference(rec_filtered, reference='global', operator='median')
print(rec_preprocessed)
CommonReferenceRecording: 384 channels - 1 segments - 30.0kHz - 1956.954s
In [9]:
# load the binary cache if it already exists, otherwise create it
if preprocess_folder_bin.exists():
    rec_preprocessed_cached_binary = si.load_extractor(preprocess_folder_bin)
else:
    # cache to binary
    rec_preprocessed_cached_binary = rec_preprocessed.save(folder=preprocess_folder_bin, format='binary', **job_kwargs)
write_binary_recording with n_jobs 20  chunk_size 30000
write_binary_recording: 100%|██████████| 1957/1957 [03:50<00:00,  8.49it/s]
In [10]:
print(rec_preprocessed_cached_binary)
BinaryRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s
  file_paths: ['/mnt/data/sam/DataSpikeSorting/imposed_motion_nick/dataset1_NP2_preprocessed_binary/traces_cached_seg0.raw']
In [11]:
if preprocess_folder_zarr.exists():
    rec_preprocessed_cached_zarr = si.load_extractor(preprocess_folder_zarr)
else:
    # cache to zarr
    rec_preprocessed_cached_zarr = rec_preprocessed.save(zarr_path=preprocess_folder_zarr,  format='zarr', **job_kwargs)
Using default zarr compressor: Blosc(cname='zstd', clevel=5, shuffle=BITSHUFFLE, blocksize=0). To use a different compressor, use the 'compressor' argument
write_zarr_recording with n_jobs 20  chunk_size 30000
write_zarr_recording: 100%|██████████| 1957/1957 [03:36<00:00,  9.04it/s]
Skipping field contact_plane_axes: only 1D and 2D arrays can be serialized
In [16]:
print(rec_preprocessed_cached_zarr)
ZarrRecordingExtractor: 384 channels - 1 segments - 30.0kHz - 1956.954s

show some traces

In [9]:
# plot and check spikes
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec.channel_ids[50:60])
Out[9]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f003e94c880>

estimate noise

We need an estimate of the noise level on each channel.

Very important: we must estimate the noise with return_scaled=False because detect_peaks() works on the raw, unscaled data (very often int16).

In [39]:
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8,6))
ax.hist(noise_levels, bins=10)
ax.set_title('noise across channel')
Out[39]:
Text(0.5, 1.0, 'noise across channel')
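
As a rough idea of what a robust noise estimate looks like, here is a minimal sketch based on the median absolute deviation (MAD), computed directly on int16-like data. It assumes a MAD-based estimator; the helper mad_noise_levels() is hypothetical and the traces are synthetic, so this is not the get_noise_levels() code.

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic int16 traces: 3 channels with known noise standard deviations (assumed data).
true_std = np.array([20.0, 40.0, 80.0])
traces = np.round(rng.normal(0.0, true_std, size=(30000, 3))).astype("int16")
# Add large negative "spikes" on channel 0; a robust estimator should mostly ignore them.
traces[::300, 0] -= 1000

def mad_noise_levels(traces):
    """Per-channel noise estimate: median absolute deviation scaled to a standard deviation."""
    traces = traces.astype("float64")
    med = np.median(traces, axis=0, keepdims=True)
    mad = np.median(np.abs(traces - med), axis=0)
    return mad / 0.6744897501960817  # MAD -> std for Gaussian noise

noise_levels = mad_noise_levels(traces)
print(noise_levels, traces[:, 0].std())
```

The plain standard deviation of channel 0 is inflated by the spikes, while the MAD-based value stays close to the underlying noise. The values are in the same unscaled units as the traces, which is why the threshold in detect_peaks() must be computed with return_scaled=False.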

detect peaks

There are 2 methods in spikeinterface, both implemented with numba:

  • 'by_channel': peaks are detected on each channel independently
  • 'locally_exclusive': if a unit fires on several channels, only the best peak on the best channel is kept. This is controlled by local_radius_um
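
The difference between the two methods can be sketched in plain numpy. This is a simplified illustration and not the numba implementation: both helper functions are hypothetical, only negative peaks are handled, and the exclusion rule is reduced to "keep a peak only if no neighbouring channel within local_radius_um has a deeper value within +/- n_shifts samples".

```python
import numpy as np

def detect_peaks_by_channel(traces, noise_levels, detect_threshold=5, n_shifts=2):
    """Negative peaks detected on each channel independently (simplified sketch)."""
    n_samples = traces.shape[0]
    center = traces[n_shifts:n_samples - n_shifts]
    mask = center < -detect_threshold * noise_levels  # below threshold, per channel
    for shift in range(1, n_shifts + 1):
        before = traces[n_shifts - shift:n_samples - n_shifts - shift]
        after = traces[n_shifts + shift:n_samples - n_shifts + shift]
        mask &= (center <= before) & (center < after)  # local minimum in time
    sample_inds, channel_inds = np.nonzero(mask)
    return sample_inds + n_shifts, channel_inds

def detect_peaks_locally_exclusive(traces, noise_levels, channel_positions,
                                   local_radius_um=100, detect_threshold=5, n_shifts=2):
    """Keep a peak only if it is the deepest one among neighbouring channels (simplified sketch)."""
    sample_inds, channel_inds = detect_peaks_by_channel(traces, noise_levels, detect_threshold, n_shifts)
    dist = np.linalg.norm(channel_positions[:, None, :] - channel_positions[None, :, :], axis=2)
    neighbours = dist <= local_radius_um
    keep = np.zeros(sample_inds.size, dtype=bool)
    for i, (s, c) in enumerate(zip(sample_inds, channel_inds)):
        window = traces[s - n_shifts:s + n_shifts + 1][:, neighbours[c]]
        keep[i] = traces[s, c] <= window.min()
    return sample_inds[keep], channel_inds[keep]

traces = np.zeros((200, 4))
traces[50, :3] = [-10.0, -20.0, -8.0]   # one spike seen on three neighbouring channels
traces[120, 3] = -15.0                  # another spike on a distant channel
noise_levels = np.ones(4)
channel_positions = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0], [0.0, 500.0]])

by_channel = detect_peaks_by_channel(traces, noise_levels)
exclusive = detect_peaks_locally_exclusive(traces, noise_levels, channel_positions)
print(by_channel, exclusive)
```

On this toy data 'by_channel' reports the first spike three times (once per channel), while the locally exclusive variant keeps it only on the channel with the largest deflection.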
In [34]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
In [40]:
peaks = detect_peaks(rec_preprocessed,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [02:09<00:00, 15.09it/s]
(2531770,)

compare compute time with cached version

When we detect peaks on the lazy object, every trace chunk is loaded and preprocessed, and then peaks are detected on it.

When we detect peaks on a cached version, the trace chunk is simply read from the saved copy.

In [41]:
peaks = detect_peaks(rec_preprocessed_cached_binary,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [01:30<00:00, 21.55it/s]
(2528737,)
In [42]:
peaks = detect_peaks(rec_preprocessed_cached_zarr,
                     method='locally_exclusive',
                     local_radius_um=100,
                     peak_sign='neg',
                     detect_threshold=5,
                     n_shifts=5,
                     noise_levels=noise_levels,
                    **job_kwargs)
print(peaks.shape)
detect peaks: 100%|██████████| 1957/1957 [01:28<00:00, 22.23it/s]
(2528737,)

Conclusion

Running peak detection on the lazy or on the cached version is an important choice.

detect_peaks() is a bit faster on the cached version (1:30) than on the lazy version (2:09).

But the total time of save() + detect_peaks() is much longer (roughly 3:30 + 1:30 = 5:00)!

Here, writing to disk is clearly a waste of time.

So the benefit of caching depends entirely on:

  1. the complexity of the preprocessing chain
  2. the write speed of the disk
  3. how many times the preprocessed recording will be consumed
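
The third point can be turned into a small back-of-the-envelope calculation. The sketch below uses the timings read from the progress bars of this notebook (binary cache) and assumes, as a simplification, that every later consumer of the traces costs about the same as detect_peaks().

```python
import math

# Timings read from the progress bars in this notebook, in seconds.
t_detect_lazy = 2 * 60 + 9     # detect_peaks on the lazy recording
t_detect_cached = 1 * 60 + 30  # detect_peaks on the cached binary recording
t_save = 3 * 60 + 50           # rec_preprocessed.save(..., format='binary')

def total_time(n_uses, cached):
    """Total wall time if the preprocessed traces are consumed n_uses times (simplified model)."""
    if cached:
        return t_save + n_uses * t_detect_cached
    return n_uses * t_detect_lazy

# Smallest number of uses for which caching is strictly faster overall.
break_even = math.floor(t_save / (t_detect_lazy - t_detect_cached)) + 1
print(total_time(1, cached=False), total_time(1, cached=True), break_even)
```

Under these assumptions caching only pays off from the sixth pass over the data. A heavier preprocessing chain or a faster disk would lower that number.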

spikeinterface template matching

Template matching is the final step used in many tools (kilosort, spyking-circus, yass, tridesclous, hdsort...)

In this step, given a catalogue (aka dictionary) of templates (aka atoms), algorithms explain the traces as a linear sum of templates plus residual noise.
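
The "linear sum of templates plus residual noise" model can be written down in a few lines of numpy. The sketch below is only an illustration under strong assumptions: the templates are synthetic, and the spike times and unit labels are taken as known, so only the amplitudes are estimated (by least squares). Real template matching algorithms also have to find the times and labels.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_channels, n_t = 400, 2, 20

# Two hypothetical templates of shape (template_samples, channels).
t = np.arange(n_t)
bump = -np.exp(-0.5 * ((t - 8) / 2.0) ** 2)
templates = np.stack([
    np.stack([50.0 * bump, 10.0 * bump], axis=1),   # unit 0: large on channel 0
    np.stack([10.0 * bump, 40.0 * bump], axis=1),   # unit 1: large on channel 1
])

# Ground truth: (start sample, unit, amplitude); the two middle spikes overlap in time.
spikes = [(30, 0, 1.0), (100, 1, 0.9), (105, 0, 1.1), (300, 1, 1.0)]

traces = rng.normal(0.0, 1.0, size=(n_samples, n_channels))
for start, unit, amp in spikes:
    traces[start:start + n_t] += amp * templates[unit]

# With known times and units, the amplitudes are the least-squares solution of
# traces = sum_k amplitude_k * shifted_template_k + residual.
design = np.zeros((n_samples * n_channels, len(spikes)))
for k, (start, unit, _) in enumerate(spikes):
    shifted = np.zeros((n_samples, n_channels))
    shifted[start:start + n_t] = templates[unit]
    design[:, k] = shifted.ravel()

amplitudes, *_ = np.linalg.lstsq(design, traces.ravel(), rcond=None)
residual = traces.ravel() - design @ amplitudes
print(np.round(amplitudes, 2), residual.std())
```

Even for the two overlapping spikes the amplitudes are recovered, and what is left in the residual is at the level of the injected noise.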

We have started to implement some template matching procedures in spikeinterface.

Here is a small demo, along with some benchmarks to compare the performance of these algorithms.

For this we use a dataset simulated with MEArec on a 32-channel Neuronexus-like probe. We first compute the true templates using the true sorting. These true templates are then given to the different methods. Finally, we apply the comparison-to-ground-truth procedure to evaluate this step alone.

In [1]:
# %matplotlib widget
%matplotlib inline
In [2]:
%load_ext autoreload
%autoreload 2
In [3]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [4]:
base_folder = Path('/mnt/data/sam/DataSpikeSorting/mearec_template_matching')
mearec_file = base_folder / 'recordings_collision_15cells_Neuronexus-32_1800s.h5'
wf_folder = base_folder / 'Waveforms_recording_15cells_Neuronexus-32_1800s'
rec_folder = base_folder /'Preprocessed_recording_15cells_Neuronexus-32_1800s'

open and preprocess

In [5]:
# load the cached recording if it exists, otherwise compute and cache it
if rec_folder.exists():
    recording = si.load_extractor(rec_folder)
else:
    recording, gt_sorting = si.read_mearec(mearec_file)
    recording = si.bandpass_filter(recording, dtype='float32')
    recording = si.common_reference(recording)
    recording = recording.save(folder=rec_folder, n_jobs=20, chunk_size=30000, progress_bar=True)

construct true templates

In [6]:
_, gt_sorting = si.read_mearec(mearec_file)
recording = si.load_extractor(rec_folder)
In [7]:
we = si.extract_waveforms(recording, gt_sorting, wf_folder, load_if_exists=True,
                           ms_before=2.5, ms_after=3.5, max_spikes_per_unit=500,
                           n_jobs=20, chunk_size=30000, progress_bar=True)
print(we)
WaveformExtractor: 32 channels - 15 units - 1 segments
  before:75 after:105 n_per_units:500
In [8]:
metrics = si.compute_quality_metrics(we, metric_names=['snr'], load_if_exists=True)
metrics
Out[8]:
snr
#0 42.573563
#1 23.475538
#2 11.677200
#3 8.544864
#4 61.134110
#5 49.281887
#6 31.793837
#7 36.275745
#8 12.932632
#9 39.769770
#10 8.230338
#11 14.968547
#12 12.002127
#13 12.905783
#14 20.285872
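
For intuition about the values in this table, a simplified definition of SNR is the peak amplitude of the template on its best channel divided by the noise level on that channel. The sketch below assumes this definition and uses a synthetic template, so it may differ in details from what compute_quality_metrics() does.

```python
import numpy as np

# Hypothetical template of shape (samples, channels) and per-channel noise levels.
template = np.zeros((60, 4))
template[20:30, 1] = -np.hanning(10) * 80.0   # main channel
template[20:30, 2] = -np.hanning(10) * 20.0   # neighbouring channel
noise_levels = np.array([4.0, 4.0, 5.0, 4.0])

# Channel carrying the largest absolute deflection, and the amplitude there.
best_channel = np.argmax(np.max(np.abs(template), axis=0))
peak_amplitude = np.max(np.abs(template[:, best_channel]))
snr = peak_amplitude / noise_levels[best_channel]
print(best_channel, snr)
```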

run several methods of template matching

A single function is used for all of them: find_spikes_from_templates()

In [9]:
from spikeinterface.sortingcomponents.template_matching import find_spikes_from_templates
In [10]:
# Some methods need the noise levels (for internal detection)
noise_levels = si.get_noise_levels(recording, return_scaled=False)
noise_levels
Out[10]:
array([3.9969404, 3.9896376, 3.8046541, 3.5555122, 3.3091464, 3.257736 ,
       3.6201818, 3.9503036, 4.079712 , 4.2103205, 3.8557687, 3.9278026,
       3.8464408, 3.651188 , 3.4105062, 3.2170172, 3.3981993, 3.7377162,
       3.9932737, 4.1710896, 4.2710056, 4.296086 , 3.7716963, 3.7748668,
       3.6391177, 3.4687228, 3.3020885, 3.3594728, 3.6073673, 3.8444421,
       4.0852304, 4.234068 ], dtype=float32)
In [11]:
# find_spikes_from_templates() supports parallel computing
job_kwargs = dict(
    n_jobs=40,
    chunk_size=30000,
    progress_bar=True
)
In [16]:
# let's build a dict to handle the parameters of each method
methods = {}
methods['naive'] =  ('naive', 
                    {'waveform_extractor' : we})
methods['tridesclous'] =  ('tridesclous',
                           {'waveform_extractor' : we,
                            'noise_levels' : noise_levels,
                            'num_closest' :3})
methods['circus'] =  ('circus',
                      {'waveform_extractor' : we,
                       'noise_levels' : noise_levels})
methods['circus-omp'] =  ('circus-omp',
                          {'waveform_extractor' : we,
                           'noise_levels' : noise_levels})


spikes_by_methods = {}
for name,  (method, method_kwargs) in methods.items():
    spikes = find_spikes_from_templates(recording, method=method, method_kwargs=method_kwargs, **job_kwargs)
    spikes_by_methods[name] = spikes
find spikes (naive): 100%|██████████| 1800/1800 [00:06<00:00, 294.34it/s]
find spikes (tridesclous): 100%|██████████| 1800/1800 [00:06<00:00, 277.50it/s]
[1] compute overlaps: 100%|██████████| 180/180 [00:00<00:00, 978.20it/s]
[2] compute amplitudes: 100%|██████████| 15/15 [00:01<00:00,  9.17it/s]
find spikes (circus): 100%|██████████| 1800/1800 [00:04<00:00, 386.06it/s]
find spikes (circus-omp): 100%|██████████| 1800/1800 [00:28<00:00, 63.54it/s]
In [17]:
# the output of every method is a numpy array with a structured (compound) dtype

spikes = spikes_by_methods['tridesclous']
print(spikes.dtype)
print(spikes.shape)
print(spikes[:5])
[('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'), ('amplitude', '<f8'), ('segment_ind', '<i8')]
(234977,)
[( 59,  0,  4, 1., 0) (309, 21,  8, 1., 0) (371, 13,  3, 1., 0)
 (623, 30, 14, 1., 0) (713, 31, 13, 1., 0)]
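
A structured array like this one can be handled with plain numpy. The short sketch below rebuilds by hand the five spikes printed above and shows field access and boolean selection. The 30 kHz sampling rate used for the conversion to seconds is the one reported for this recording.

```python
import numpy as np

spike_dtype = [('sample_ind', '<i8'), ('channel_ind', '<i8'), ('cluster_ind', '<i8'),
               ('amplitude', '<f8'), ('segment_ind', '<i8')]

# The five spikes printed above, rebuilt by hand.
spikes = np.array([(59, 0, 4, 1., 0), (309, 21, 8, 1., 0), (371, 13, 3, 1., 0),
                   (623, 30, 14, 1., 0), (713, 31, 13, 1., 0)], dtype=spike_dtype)

# Each field behaves like a regular 1D array.
sample_inds = spikes['sample_ind']

# Boolean masks select spikes, e.g. all spikes of cluster 8.
cluster_8 = spikes[spikes['cluster_ind'] == 8]

# Convert sample indices to seconds for a 30 kHz recording.
times_s = spikes['sample_ind'] / 30000.0
print(sample_inds, cluster_8, times_s)
```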

check performance method by method

For this:

  1. we transform the spikes vector into a sorting object
  2. use the compare_sorter_to_ground_truth() function to compute performance
  3. plot the agreement matrix
  4. plot accuracy vs SNR
  5. plot collision vs similarity

Note:

  • as we provide the true template list, every agreement matrix is supposed to be square. The performance can be read on the diagonal. A perfect matching would have only ones on the diagonal.
  • The dataset here is one of the datasets used in the collision paper, so we can also run a fine-grained benchmark by inspecting collisions.
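
To make the accuracy, recall and precision curves easier to read, here is a minimal sketch of how such scores can be derived for one unit from a ground-truth spike train and a tested spike train. The matching function is a hypothetical greedy matcher with a tolerance in samples. It is not the compare_sorter_to_ground_truth() implementation, and the spike times are made up.

```python
import numpy as np

def count_matches(gt_times, tested_times, delta):
    """Greedy one-to-one matching of two sorted spike trains within +/- delta samples."""
    i = j = n_match = 0
    while i < len(gt_times) and j < len(tested_times):
        diff = tested_times[j] - gt_times[i]
        if abs(diff) <= delta:
            n_match += 1
            i += 1
            j += 1
        elif diff < 0:
            j += 1   # tested spike too early to match anything: false positive
        else:
            i += 1   # ground-truth spike was missed: false negative
    return n_match

gt_times = np.array([100, 400, 900, 1500, 2100])
tested_times = np.array([102, 398, 1200, 1499])
delta = 12  # samples, i.e. 0.4 ms at 30 kHz (tolerance chosen for illustration)

tp = count_matches(gt_times, tested_times, delta)
fn = len(gt_times) - tp
fp = len(tested_times) - tp

accuracy = tp / (tp + fn + fp)
recall = tp / (tp + fn)
precision = tp / (tp + fp)
print(tp, fn, fp, accuracy, recall, precision)
```

Here 3 of the 5 true spikes are found and 1 detected spike is spurious, which gives accuracy 0.5, recall 0.6 and precision 0.75.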
In [18]:
# load metrics for snr on true template
metrics = we.load_extension('quality_metrics').get_metrics()
In [20]:
templates = we.get_all_templates()

comparisons = {}
for name,  (method, method_kwargs) in methods.items():
    spikes = spikes_by_methods[name]

    sorting = si.NumpySorting.from_times_labels(spikes['sample_ind'], spikes['cluster_ind'], recording.get_sampling_frequency())
    print(sorting)

    comp = si.compare_sorter_to_ground_truth(gt_sorting, sorting)
    

    fig, axs = plt.subplots(ncols=2)
    si.plot_agreement_matrix(comp, ax=axs[0])
    si.plot_sorting_performance(comp, metrics, performance_name='accuracy', metric_name='snr', ax=axs[1], color='g')
    si.plot_sorting_performance(comp, metrics, performance_name='recall', metric_name='snr', ax=axs[1], color='b')
    si.plot_sorting_performance(comp, metrics, performance_name='precision', metric_name='snr', ax=axs[1], color='r')
    axs[0].set_title(name)
    axs[1].set_ylim(0.8, 1.1)
    axs[1].legend(['accuracy', 'recall', 'precision'])
    
    comp = si.CollisionGTComparison(gt_sorting, sorting)
    comparisons[name] = comp
    fig, ax = plt.subplots()
    si.plot_comparison_collision_by_similarity(comp, templates, figure=fig)
    fig.suptitle(name)

plt.show()
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz
NumpySorting: 15 units - 1 segments - 30.0kHz

comparison of methods 2 by 2

In [ ]:
names = list(comparisons.keys())
n = len(names)

for r, name0 in enumerate(names):
    for c, name1 in enumerate(names):
        if r<=c:
            continue

        fig, ax = plt.subplots()
        val0 = comparisons[name0].get_performance()['accuracy']
        val1 = comparisons[name1].get_performance()['accuracy']
        ax.scatter(val0, val1)
        ax.set_xlabel(name0)
        ax.set_ylabel(name1)
        ax.plot([0,1], [0, 1], color='k')
        ax.set_title('accuracy')
        ax.set_xlim(0.6, 1)
        ax.set_ylim(0.6, 1)

conclusion

  • tridesclous and circus-omp are the clear winners in terms of performance
  • of these two, tridesclous is much faster (about 6 s vs 28 s for circus-omp in the run above)
  • Improvements are still needed because performance is far from perfect

spikeinterface destripe

destripe processing in spikeinterface

Author : Samuel Garcia

Olivier Winter has developed for IBL a standard pre-processing chain in ibllib to clean the traces before spike sorting. See this

This procedure is called "destripe". It removes artefacts that are present on all channels (common noise).

The main idea is to have this:

  1. filter
  2. align samples (phase shift)
  3. remove common noise
  4. apply spatial filter and bad channel interpolation

Except for step 4, all steps are available in spikeinterface.

spikeinterface.toolkit.preprocessing provides classes and functions to build what we call a lazy chain of processing.

Here is an example with 4 files kindly provided by Olivier Winter to illustrate the spikeinterface implementation of this destripe procedure.

In [42]:
# %matplotlib widget
%matplotlib inline
In [19]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [20]:
from pathlib import Path

import time

import numpy as np
import matplotlib.pyplot as plt


import spikeinterface.full as si
In [21]:
base_folder = Path('/media/samuel/dataspikesorting/DataSpikeSorting/olivier_destripe/')

folder1 = base_folder / '4c04120d-523a-4795-ba8f-49dbb8d9f63a'
folder2 = base_folder / '68f06c5f-8566-4a4f-a4b1-ab8398724913'
folder3 = base_folder / '8413c5c6-b42b-4ec6-b751-881a54413628'
folder4 = base_folder / 'f74a6b9a-b8a5-4c80-9c30-7dd4cdbb48c0'
data_folders = [folder1, folder2, folder3, folder4]

Build the preprocessing chain

In spikeinterface we have:

  • bandpass_filter()
  • common_reference(): this removes common noise (global or local) by subtracting the median (or average)
  • phase_shift(): this compensates for the ADC sampling shift across channels by applying the reverse shift in the FFT domain.

These can be combined to get more or less the same result as "destripe".
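
The idea behind the phase shift can be illustrated with the FFT shift theorem: delaying a signal by d samples multiplies its spectrum by exp(-2πi f d), and d does not need to be an integer. The sketch below is a toy version under simple assumptions (a periodic test signal, made-up per-channel delays, a hypothetical fft_shift() helper). It is not the phase_shift() implementation.

```python
import numpy as np

def fft_shift(traces, shifts):
    """Delay each channel by a (fractional) number of samples using the FFT shift theorem."""
    n_samples = traces.shape[0]
    freqs = np.fft.rfftfreq(n_samples)            # cycles per sample
    spectrum = np.fft.rfft(traces, axis=0)
    phase = np.exp(-2j * np.pi * freqs[:, None] * shifts[None, :])
    return np.fft.irfft(spectrum * phase, n=n_samples, axis=0)

n_samples = 512
t = np.arange(n_samples)
# Sine with an integer number of cycles in the window, so the signal is exactly periodic.
common = np.sin(2 * np.pi * 8 * t / n_samples)

# Hypothetical ADC offsets in fractions of a sample: channel k sees the same signal slightly later.
adc_delays = np.array([0.0, 0.25, 0.5, 0.75])
recorded = np.stack([np.sin(2 * np.pi * 8 * (t + d) / n_samples) for d in adc_delays], axis=1)

# Apply the reverse shift to re-align all channels on channel 0.
aligned = fft_shift(recorded, adc_delays)
print(np.abs(recorded - common[:, None]).max(), np.abs(aligned - common[:, None]).max())
```

Once the channels are aligned, an artefact that is truly simultaneous on all channels becomes identical across channels, which is what makes the subsequent common reference more effective.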

Here we will compare 2 preprocessing chains:

  1. filter > cmr
  2. filter > phase_shift > cmr

Step 4 (kfilter) is not implemented yet, but this should be done soon.
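
The common-noise removal used at the end of both chains can be sketched in numpy as a global common median reference: at each sample, the median across channels is subtracted from every channel. The data below are synthetic (a "stripe" identical on all channels plus one spike on a single channel), and the sketch ignores everything else common_reference() supports, such as local references or the average operator.

```python
import numpy as np

rng = np.random.default_rng(7)
n_samples, n_channels = 1000, 16

# Independent noise on each channel plus a "stripe": an artefact identical on all channels.
independent = rng.normal(0.0, 5.0, size=(n_samples, n_channels))
stripe = np.zeros(n_samples)
stripe[480:520] = 200.0
traces = independent + stripe[:, None]

# A real spike seen on a single channel should survive the referencing.
traces[100:105, 3] -= 80.0

# Global common median reference: subtract, at each sample, the median across channels.
cmr = traces - np.median(traces, axis=1, keepdims=True)
print(np.abs(traces[480:520]).min(), np.abs(cmr[480:520]).max(), cmr[100:105, 3].mean())
```

The stripe disappears because it is shared by all channels, while the single-channel spike is barely affected since it hardly moves the median.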

In [38]:
# let's write a function that builds the chain and plots intermediate results

def preprocess_steps(rec, time_range=None, clim=(-80, 80), figsize=(15, 10)):
    
    # chain 1. : filter + cmr
    rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000., dtype='float32')
    rec_cmr = si.common_reference(rec_filtered, reference='global', operator='median')
    
    # chain 2. : filter + phase_shift + cmr
    rec_pshift = si.phase_shift(rec_filtered)
    rec_cmr2 = si.common_reference(rec_pshift, reference='global', operator='median')
    
    
    # rec
    fig, axs = plt.subplots(ncols=5, sharex=True, sharey=True, figsize=figsize)
    
    ax = axs[0]
    ax.set_title('raw')
    si.plot_timeseries(rec, ax=ax,  with_colorbar=False) # clim=clim,
    
    # filter
    
    ax = axs[1]
    ax.set_title('filtered')
    si.plot_timeseries(rec_filtered, ax=axs[1], clim=clim, with_colorbar=False)
    
    # filter + cmr
    
    # rec_preprocessed
    ax = axs[2]
    ax.set_title('filtered + cmr')
    si.plot_timeseries(rec_cmr, ax=axs[2], clim=clim, with_colorbar=False)
    
    # filter + phase_shift
    
    ax = axs[3]
    ax.set_title('filtered + phase_shift')
    si.plot_timeseries(rec_pshift, ax=ax, clim=clim, with_colorbar=False)
    
    # filtered + phase_shift + cmr
    
    ax = axs[4]
    ax.set_title('filtered + phase_shift + cmr')
    si.plot_timeseries(rec_cmr2, ax=ax, clim=clim, with_colorbar=True)

    # optionally a time range can be given
    if time_range is not None:
        ax.set_xlim(*time_range)

dataset 1

In [30]:
rec = si.read_cbin_ibl(folder1)
preprocess_steps(rec)
In [31]:
# zoom on a stripe
preprocess_steps(rec, time_range=(0.95, 0.97))

dataset 2

In [32]:
rec = si.read_cbin_ibl(folder2)
preprocess_steps(rec)
In [33]:
rec = si.read_cbin_ibl(folder2)
preprocess_steps(rec, time_range=(0.2, .3))

dataset 3

In [34]:
rec = si.read_cbin_ibl(folder3)
preprocess_steps(rec)
In [35]:
rec = si.read_cbin_ibl(folder3)
preprocess_steps(rec, time_range=(0.797, .801))

dataset 4

In [39]:
rec = si.read_cbin_ibl(folder4)
preprocess_steps(rec, clim=(-50, 50))
In [41]:
preprocess_steps(rec, clim=(-50, 50), time_range=(0.368, .375))

conclusion

Here we demonstrated the modular way of building a preprocessing chain directly in spikeinterface. This is particularly useful because:

  1. the same preprocessing can be applied for different sorters
  2. the preprocessing can be cached in parallel using rec.save(...) in binary or zarr format
  3. every step can be parameterized depending on the input dataset and the compute resources available.

Collision paper spike sorting performance

Spike sorting performance against spike collisions (figure 2-3-5)

In this notebook, we describe how to generate the figures for all the studies, i.e. for all rate and correlation levels, in a systematic manner. However, while by default the figures were saved as .pdf, here we restrict the ranges of rates and correlations so that only a single figure is displayed. Feel free to modify these ranges in the scripts to display the figures for other conditions.

In [1]:
import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import gridspec

import MEArec as mr
import spikeinterface.full as si
In [2]:
study_base_folder = Path('../data/study/')

Plot global spike sorting performance (Figure 2)

In [1]:
res = {}

rate_levels = [5]
corr_levels = [0]

for rate_level in rate_levels:
    for corr_level in corr_levels:

        fig = plt.figure(figsize=(15,5))
        gs = gridspec.GridSpec(2, 3, figure=fig)

        study_folder = study_base_folder / f'20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        
        study = si.GroundTruthStudy(study_folder)
        study.run_comparisons(exhaustive_gt=True)

        ax_1 = plt.subplot(gs[0, 0])
        ax_2 = plt.subplot(gs[0, 1:])
        ax_3 = plt.subplot(gs[1, 1:])
        ax_4 = plt.subplot(gs[1, 0])

        for ax in [ax_1, ax_2, ax_3, ax_4]:
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

        ax_2.tick_params(labelbottom=False)
        ax_2.set_xlabel('')

        si.plot_gt_study_run_times(study, ax=ax_1)
        si.plot_gt_study_unit_counts(study, ax=ax_2)
        si.plot_gt_study_performances_averages(study, ax=ax_3)
        si.plot_gt_study_performances_by_template_similarity(study, ax=ax_4)

        plt.tight_layout()

Plot collision recall as function of the lags (Figure 3)

In [2]:
for rate_level in rate_levels:
    for corr_level in corr_levels:
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)

        for rec_name in res[(rate_level, corr_level)].rec_names:
            res[(rate_level, corr_level)].compute_waveforms(rec_name)

        si.plot_study_comparison_collision_by_similarity(res[(rate_level, corr_level)], 
                                                         show_legend=False, ylim=(0.4, 1))
        plt.tight_layout()
Python 3.6.9 (default, Jan 26 2021, 15:33:00) 
[GCC 8.4.0] does not support parallel processing

Plot collision recall as function of the lag and/or cosine similarity (supplementary figures)

In [3]:
gs = gridspec.GridSpec(len(rate_levels), len(corr_levels))
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        ax = plt.subplot(gs[i, j])

        if i == 0 and j == 0:
            show_legend = True
        else:
            show_legend = False
        si.plot_study_comparison_collision_by_similarity_range(res[(rate_level, corr_level)], show_legend=show_legend, similarity_range=[0.5, 1], ax=ax, ylim=(0.3, 1))

        ax.set_title(f'Rate {rate_level} Hz, Corr {corr_level}' )
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if rate_level != rate_levels[-1]:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')
        else:
            ax.set_xlabel('lags (ms)')

        if corr_level != corr_levels[0]:
            ax.tick_params(labelleft=False)
            ax.set_ylabel('')
        else:
            ax.set_ylabel('collision accuracy')
In [4]:
gs = gridspec.GridSpec(len(rate_levels), len(corr_levels))
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        ax = plt.subplot(gs[i, j])

        if i == 0 and j == 0:
            show_legend = True
        else:
            show_legend = False
        si.plot_study_comparison_collision_by_similarity_ranges(res[(rate_level, corr_level)], show_legend=show_legend, ax=ax)

        ax.set_title(f'Rate {rate_level} Hz, Corr {corr_level}' )
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if rate_level != rate_levels[-1]:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel('')
        else:
            ax.set_xlabel('similarity')

        if corr_level != corr_levels[0]:
            ax.tick_params(labelleft=False)
            ax.set_ylabel('')
        else:
            ax.set_ylabel('collision accuracy')
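
The similarity axis of these plots refers to the cosine similarity between pairs of templates. As a reminder of what that quantity is, here is a minimal numpy sketch on synthetic templates. The function name is hypothetical and the flattening over samples and channels is an assumption about how templates are compared.

```python
import numpy as np

def template_cosine_similarity(templates):
    """Pairwise cosine similarity between templates of shape (units, samples, channels)."""
    flat = templates.reshape(templates.shape[0], -1)
    unit_vectors = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    return unit_vectors @ unit_vectors.T

t = np.linspace(0, 1, 30)
wave = -np.sin(np.pi * t) ** 2
templates = np.zeros((3, 30, 2))
templates[0, :, 0] = 60.0 * wave          # unit 0 on channel 0
templates[1, :, 0] = 30.0 * wave          # unit 1: same shape and channel, half the amplitude
templates[2, :, 1] = 60.0 * wave          # unit 2: same shape on another channel

similarity = template_cosine_similarity(templates)
print(np.round(similarity, 3))
```

Units 0 and 1 have similarity 1 because cosine similarity ignores amplitude, whereas units 0 and 2 have similarity 0 because they do not share any channel. Pairs with high similarity are the ones for which collisions are expected to be hardest to resolve.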

Plot average collision recall over multiple conditions, as function of the lags (Figure 5)

In [9]:
rate_levels = [5,10,15]
corr_levels = [0, 0.1, 0.2]

gs = gridspec.GridSpec(1, 2)
ax = plt.subplot(gs[0, 0])
curves = {}
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
        res[(rate_level, corr_level)].precompute_scores_by_similarities()
        
        for sorter_name in res[(rate_level, corr_level)].sorter_names:
            data = res[(rate_level, corr_level)].get_mean_over_similarity_range([0.5, 1], sorter_name)
            if not sorter_name in curves:
                curves[sorter_name] = [data]
            else:
                curves[sorter_name] += [data]

lags = res[(rate_level, corr_level)].get_lags()
for sorter_name in res[(rate_level, corr_level)].sorter_names:
    curves[sorter_name] = np.array(curves[sorter_name])
    mean_sorter = curves[sorter_name].mean(0)
    std_sorter = curves[sorter_name].std(0)
    ax.plot(lags[:-1] + (lags[1]-lags[0]) / 2, mean_sorter, label=sorter_name)
    ax.fill_between(lags[:-1] + (lags[1]-lags[0]) / 2, mean_sorter-std_sorter,mean_sorter+std_sorter, alpha=0.2)

ax.legend()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('lag (ms)')
ax.set_ylabel('collision accuracy')
Out[9]:
Text(0, 0.5, 'collision accuracy')

Plotting the average collision recall over multiple conditions, as function of the similarity

In [5]:
rate_levels = [5,10,15]
corr_levels = [0, 0.1, 0.2]
res = {}
gs = gridspec.GridSpec(1, 2)
ax = plt.subplot(gs[0, 0])
curves = {}
similarity_ranges = np.linspace(-0.4, 1, 8)
for i, rate_level in enumerate(rate_levels):
    for j, corr_level in enumerate(corr_levels):
        
        study_folder = f'../data/study/20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32'
        res[(rate_level, corr_level)] = si.CollisionGTStudy(study_folder)
        res[(rate_level, corr_level)].run_comparisons(exhaustive_gt=True, collision_lag=2, nbins=11)
        res[(rate_level, corr_level)].precompute_scores_by_similarities()
        
        for sorter_name in res[(rate_level, corr_level)].sorter_names:

            all_similarities = res[(rate_level, corr_level)].all_similarities[sorter_name]
            all_recall_scores = res[(rate_level, corr_level)].all_recall_scores[sorter_name]

            order = np.argsort(all_similarities)
            all_similarities = all_similarities[order]
            all_recall_scores = all_recall_scores[order, :]

            mean_recall_scores = []
            std_recall_scores = []
            for k in range(similarity_ranges.size - 1):
                cmin, cmax = similarity_ranges[k], similarity_ranges[k + 1]
                amin, amax = np.searchsorted(all_similarities, [cmin, cmax])
                value = np.mean(all_recall_scores[amin:amax])
                mean_recall_scores += [np.nan_to_num(value)]

            xaxis = np.diff(similarity_ranges)/2 + similarity_ranges[:-1]

            data = mean_recall_scores
            if not sorter_name in curves:
                curves[sorter_name] = [data]
            else:
                curves[sorter_name] += [data]

for sorter_name in res[(rate_level, corr_level)].sorter_names:
    curves[sorter_name] = np.array(curves[sorter_name])
    mean_sorter = curves[sorter_name].mean(0)
    std_sorter = curves[sorter_name].std(0)
    ax.plot(xaxis, mean_sorter, label=sorter_name)
    ax.fill_between(xaxis, mean_sorter-std_sorter,mean_sorter+std_sorter, alpha=0.2)


ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_xlabel('cosine similarity')
#ax.set_ylabel('collision accuracy')
#ax.set_yticks([])

plt.tight_layout()
/home/cure/.local/lib/python3.6/site-packages/numpy/core/fromnumeric.py:3373: RuntimeWarning: Mean of empty slice.
  out=out, **kwargs)
/home/cure/.local/lib/python3.6/site-packages/numpy/core/_methods.py:170: RuntimeWarning: invalid value encountered in double_scalars
  ret = ret.dtype.type(ret / rcount)
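
The "Mean of empty slice" warning comes from similarity bins that contain no pair of units. The binning step of the cell above can be sketched on synthetic data as follows. This variant is an assumption about one way to handle it: it leaves empty bins as NaN instead of averaging an empty slice (the cell above turns them into 0 with np.nan_to_num). All arrays here are made up.

```python
import numpy as np

rng = np.random.default_rng(5)

# Synthetic stand-ins for all_similarities and all_recall_scores (assumed data).
similarities = rng.uniform(0.2, 1.0, 200)   # no pair below 0.2, so the first bins are empty
recall = np.clip(1.0 - 0.3 * similarities[:, None] + rng.normal(0, 0.02, (200, 10)), 0, 1)

order = np.argsort(similarities)
similarities, recall = similarities[order], recall[order]

edges = np.linspace(-0.4, 1, 8)
centers = edges[:-1] + np.diff(edges) / 2

mean_recall = np.full(edges.size - 1, np.nan)
for k in range(edges.size - 1):
    lo, hi = np.searchsorted(similarities, [edges[k], edges[k + 1]])
    if hi > lo:                               # skip empty bins instead of averaging nothing
        mean_recall[k] = recall[lo:hi].mean()
print(np.round(centers, 2), np.round(mean_recall, 3))
```

Keeping NaN for empty bins makes the gap visible in the plot, while replacing it by 0 draws a point at zero recall that was never measured.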

Collision paper simulated recordings

Simulated recordings overview (figure 1)

This notebook reproduces Figure 1 of the manuscript: "How do spike collisions affect spike sorting performance?"

To run this notebook, you first need to run the generate_recordings.ipynb notebook.

In [1]:
import shutil
import sys
from pathlib import Path

import numpy as np
import scipy.spatial

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import MEArec as mr
import spikeinterface.full as si


my_cmap = plt.get_cmap('winter')
cNorm  = colors.Normalize(vmin=0, vmax=1)
scalarMap = plt.cm.ScalarMappable(norm=cNorm, cmap=my_cmap)
In [ ]:
sys.path.append("../utils")

from generation_utils import generation_params
from study_utils import generate_study
In [ ]:
recordings_folder = Path('../data/recordings/')
In [2]:
# define some parameters

nb_traces = 10 # for panel I
window_ms = 20 #for CC plots
bin_ms = 0.2 # for CC plots
n_cell = 20 # number of cells in the recording
lag_time = generation_params['lag_time']*1000
corr_level = 0 # to select the appropriate recording if several (run generation first)
rate_level = 5 # to select the appropriate recording if several (run generation first)
In [8]:
# We use the plotting.py script to ease the creation of figures with several panels. 
figA, axA = plt.subplots()

# We load the file
rec_file = recordings_folder / f'rec0_20cells_5noise_{corr_level}corr_{rate_level}rate_Neuronexus-32.h5'

mearec_object = mr.load_recordings(rec_file)
rec = si.MEArecRecordingExtractor(rec_file)
sorting_gt = si.MEArecSortingExtractor(rec_file)

waveforms_path = Path('.') / 'tmp'
waveforms_path.mkdir(exist_ok=True)

waveforms = si.extract_waveforms(rec, sorting_gt, waveforms_path, ms_before=3, ms_after=3)

original_templates = waveforms.get_all_templates()
snrs = np.array([i for i in si.compute_snrs(waveforms).values()])
rates = np.array([i for i in si.compute_firing_rate(waveforms).values()])


## Plotting the probe layout and the cell positions
si.plot_unit_localization(waveforms, ax=axA)
axA.set_ylabel('y (um)')
axA.set_xlabel('x (um)')
In [9]:
figB, axB = plt.subplots(ncols=3, figsize=(12, 7))

# named unit_colors so that the matplotlib.colors module imported above is not shadowed
unit_colors = {'#0' : 'k', '#16' : 'r'}

similarities = si.compute_template_similarity(waveforms)

## Plotting example of pair with selected similarity
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[0]], unit_ids=['#0'], unit_colors=unit_colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[0]], unit_ids=['#16'], unit_colors=unit_colors)
axB[0].set_title('(#0, #16) similarity %02g' %similarities[0, 16])

unit_colors = {'#0' : 'k', '#10' : 'r'}
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[1]], unit_ids=['#0'], unit_colors=unit_colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[1]], unit_ids=['#10'], unit_colors=unit_colors)
axB[1].set_title('(#0, #10) similarity %02g' %similarities[0, 10])

unit_colors = {'#0' : 'k', '#1' : 'r'}
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[2]], unit_ids=['#0'], unit_colors=unit_colors)
si.plot_unit_templates(waveforms, plot_templates=True, axes=[axB[2]], unit_ids=['#1'], unit_colors=unit_colors)
axB[2].set_title('(#0, #1) similarity %02g' %similarities[0, 1])
figB.tight_layout()
In [11]:
figC, axC = plt.subplots()

## Plotting the similarity matrix
im = axC.imshow(similarities, cmap='viridis',
                aspect='auto',
                origin='lower',
                interpolation='none',
                extent=(-0.5, n_cell-1+0.5, -0.5, n_cell-1+0.5))
axC.set_xlabel('# cell')
axC.set_ylabel('# cell')
plt.colorbar(im, ax=axC, label='cosine similarity')
Out[11]:
<matplotlib.colorbar.Colorbar at 0x7f6c27a39208>
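
To make the quantity shown in this matrix concrete, here is a minimal NumPy sketch of a cosine similarity between templates. It assumes that each template is flattened over samples and channels before the comparison; this is an illustration on random data, not necessarily what si.compute_template_similarity does internally. The function name cosine_similarity_matrix is hypothetical.

```python
import numpy as np

def cosine_similarity_matrix(templates):
    """templates: (n_units, n_samples, n_channels) -> (n_units, n_units).

    Assumes no template is identically zero (its norm would be 0).
    """
    flat = templates.reshape(templates.shape[0], -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    unit_vectors = flat / norms
    return unit_vectors @ unit_vectors.T

rng = np.random.default_rng(0)
toy_templates = rng.normal(size=(4, 30, 8))   # random stand-ins for real templates
toy_templates[3] = 2.0 * toy_templates[0]     # unit 3 is a scaled copy of unit 0

sim = cosine_similarity_matrix(toy_templates)
```

Because the measure is normalized, the scaled copy has a similarity of 1 with the original: cosine similarity compares template shapes, not amplitudes.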
In [12]:
figDE, axDE = plt.subplots(nrows=2)

centers = np.array([v for v in si.compute_unit_centers_of_mass(waveforms).values()])
real_centers = mearec_object.template_locations[:]

distances = scipy.spatial.distance_matrix(centers, centers)
real_distances =  scipy.spatial.distance_matrix(real_centers, real_centers)

# Plotting the distribution of similarities as function of distance (either real or estimated)
axDE[0].plot(distances.flatten(), similarities.flatten(), '.', label='Center of Mass')
axDE[0].plot(real_distances.flatten(), similarities.flatten(), '.', label='Real position')
axDE[0].legend()
axDE[0].set_xlabel('distances (um)')
axDE[0].set_ylabel('cosine similarity')

counts, bin_edges = np.histogram(similarities.flatten(), 10)
axDE[1].bar(bin_edges[1:], counts/float(counts.sum()), width=bin_edges[1]-bin_edges[0])
axDE[1].set_xlabel('cosine similarity')
axDE[1].set_ylabel('probability')
Out[12]:
Text(0, 0.5, 'probability')
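
Flattening the full matrices, as done in this cell, counts every pair of units twice and also includes the diagonal (distance 0, similarity 1). A possible variant, sketched here on made-up positions with NumPy only, keeps each pair once by indexing the upper triangle.

```python
import numpy as np

n_units = 4
rng = np.random.default_rng(1)
positions = rng.uniform(0, 100, size=(n_units, 2))  # made-up (x, y) positions in um

# Pairwise Euclidean distances, shape (n_units, n_units)
deltas = positions[:, None, :] - positions[None, :, :]
distances = np.linalg.norm(deltas, axis=-1)

# Indices of the strict upper triangle: each pair (i < j) appears exactly once
i, j = np.triu_indices(n_units, k=1)
pair_distances = distances[i, j]
```

The same (i, j) indices can be applied to the similarity matrix so that both axes of the scatter plot refer to the same pairs.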
In [14]:
## Plotting the cross-correlograms (CC) of the first three units
w = si.plot_crosscorrelograms(sorting_gt, ['#%s' %i for i in range(0,3)], 
                              bin_ms=bin_ms, window_ms=window_ms, symmetrize=True)
figF = w.figure
Out[14]:
<spikeinterface.widgets.correlograms.CrossCorrelogramsWidget at 0x7f6bae84ae80>
In [16]:
figGH, axGH = plt.subplots(nrows=2)

ccs, lags = si.compute_correlograms(sorting_gt, bin_ms=bin_ms, window_ms=window_ms, symmetrize=True)
ccs = ccs.reshape(n_cell**2, ccs.shape[2])
mask = np.ones(n_cell**2, dtype=bool)  # np.bool was removed from recent NumPy versions
mask[np.arange(0, n_cell**2, n_cell) + np.arange(n_cell)] = False
mean_cc = np.mean(ccs[mask], 0)
std_cc = np.std(ccs[mask], 0)

## Plotting the average CC
xaxis = (lags[:-1] - lags[:-1].mean())
axGH[0].plot(xaxis, mean_cc, lw=2, c='r')
axGH[0].fill_between(xaxis, mean_cc-std_cc,mean_cc+std_cc, color='0.5', alpha=0.5)
axGH[0].set_xlabel('time (ms)')
axGH[0].set_ylabel('cross correlation')
ymin, ymax = axGH[0].get_ylim()
axGH[0].plot([-lag_time,-lag_time],[ymin,ymax],'k--')
axGH[0].plot([lag_time,lag_time],[ymin,ymax],'k--')

mask = np.zeros(n_cell**2, dtype=bool)
mask[np.arange(0, n_cell**2, n_cell) + np.arange(n_cell)] = True
mean_cc = np.mean(ccs[mask], 0)
std_cc = np.std(ccs[mask], 0)

xaxis = (lags[:-1] - lags[:-1].mean())
axGH[1].plot(xaxis, mean_cc, lw=2, c='r')
axGH[1].fill_between(xaxis, mean_cc-std_cc,mean_cc+std_cc, color='0.5', alpha=0.5)
axGH[1].set_ylabel('auto correlation')
ymin, ymax = axGH[1].get_ylim()
axGH[1].plot([-lag_time,-lag_time],[ymin,ymax],'k--')
axGH[1].plot([lag_time,lag_time],[ymin,ymax],'k--')
Out[16]:
[<matplotlib.lines.Line2D at 0x7f6c702a1358>]
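
The masks used in this cell separate auto-correlograms (the diagonal of the unit-by-unit matrix) from cross-correlograms once the first two axes have been flattened. A small sketch with n = 4 shows that the index expression selects exactly the diagonal entries:

```python
import numpy as np

n = 4
# Flat index of entry (k, k) in a row-major n x n matrix is k*n + k
flat_diag = np.arange(0, n**2, n) + np.arange(n)

auto_mask = np.zeros(n**2, dtype=bool)
auto_mask[flat_diag] = True

# Equivalent and arguably more readable construction
same_mask = np.eye(n, dtype=bool).ravel()
```

The complement of auto_mask keeps the n*(n-1) off-diagonal entries, which are the cross-correlograms averaged in the first panel.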
In [18]:
## Plotting timeseries
w = si.plot_timeseries(rec, time_range=(5,5.1), channel_ids=['%s' %i for i in range(1,nb_traces)], color='k')
figI = w.figure
Out[18]:
<spikeinterface.widgets.timeseries.TimeseriesWidget at 0x7f6badf61f60>

Collision paper generate recordings

Generation of the recordings

In this notebook, we will generate with MEArec all the recordings needed to populate the study and compare the sorters. First, we need a function that, given a dictionary of parameters, generates a single recording. The recording parameters can be defined as follows:

In [4]:
import os
import sys
import shutil
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

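# neo and quantities are needed to build the spike trains in generate_single_recording
import neo
import quantities as pq

import MEArec as mr
import spikeinterface.full as si

# simpleaxis and make_matching_events, used for the sanity plots below,
# are not defined in this notebook and must be provided separately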
In [8]:
sys.path.append('../utils/')

from corr_spike_trains import CorrelatedSpikeGenerator
In [3]:
generation_params = {
    'probe' : 'Neuronexus-32', #layout of the probe used
    'duration' : 30*60, # total duration of the recording, in seconds
    'n_cell' : 20, # number of cells that will be injected
    'fs' : 30000., # sampling rate, in Hz
    'lag_time' : 0.002,  # half refractory period, in seconds
    'make_plots' : True,
    'generate_recording' : True,
    'noise_level' : 5,
    'templates_seed' : 42,
    'noise_seed' : 42,
    'global_path' : os.path.abspath('../'),
    'study_number' : 0,
    'save_plots' : True,
    'method' : 'brette', # 'poisson' | 'brette'
    'corr_level' : 0,
    'rate_level' : 5, #Hz
    'nb_recordings' : 5
}

With these parameters, we will create 20 neurons, and the correlations between their spike trains will be generated via the mixture process of [Brette et al., 2009]. The function below generates a single recording. It assumes that the folder ../data/templates/ contains a file named templates_{probe}_100.h5 with all the pre-generated templates that will be used by MEArec.

In [5]:
def generate_single_recording(params=generation_params):

    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings') 

    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    template_filename = os.path.join(paths['templates'], f'templates_{probe}_100.h5')
    recording_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')
    plot_filename = os.path.join(paths['recordings'], f'rec{study_number}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.pdf')

    spikerate = params['rate_level']
    n_spike_alone = int(spikerate * params['duration'])

    print('Total target rate:', params['rate_level'], "Hz")
    print('Basal rate:', spikerate, "Hz")


    # collision lag range
    lag_sample = int(params['lag_time'] * params['fs'])

    refactory_period = 2 * params['lag_time']

    spiketimes = []

    if params['method'] == 'poisson':
        print('Spike trains generated as independent poisson sources')
        
        for i in range(params['n_cell']):
            
            #~ n = n_spike_alone + n_collision_by_pair * (params['n_cell'] - i - 1)
            n = n_spike_alone
            #~ times = np.random.rand(n_spike_alone) * params['duration']
            times = np.random.rand(n) * params['duration']
            
            times = np.sort(times)
            spiketimes.append(times)

    elif params['method'] == 'brette':
        print('Spike trains generated as compound mixtures')
        C = np.ones((params['n_cell'], params['n_cell']))
        C = params['corr_level'] * np.maximum(C, C.T)
        #np.fill_diagonal(C, 0*np.ones(params['n_cell']))

        rates = params['rate_level']*np.ones(params['n_cell'])

        cor_spk = CorrelatedSpikeGenerator(C, rates, params['n_cell'])
        cor_spk.find_mixture(iter=1e4)
        res = cor_spk.mixture_process(tauc=refactory_period/2, t=params['duration'])
        
        # collect the spike times of each cell
        for i in range(params['n_cell']):
            #~ print(spiketimes[i])
            mask = res[:, 0] == i
            times = res[mask, 1]
            times = np.sort(times)
            mask = (times > 0) * (times < params['duration'])
            times = times[mask]
            spiketimes.append(times)


    # remove the spikes that violate the refractory period
    for i in range(params['n_cell']):
        times = spiketimes[i]
        ind, = np.nonzero(np.diff(times) < refactory_period)
        ind += 1
        times = np.delete(times, ind)
        assert np.sum(np.diff(times) < refactory_period) ==0
        spiketimes[i] = times

    # make neo spiketrains
    spiketrains = []
    for i in range(params['n_cell']):
        mask = np.where(spiketimes[i] > 0)
        spiketimes[i] = spiketimes[i][mask] 
        spiketrain = neo.SpikeTrain(spiketimes[i], units='s', t_start=0*pq.s, t_stop=params['duration']*pq.s)
        spiketrain.annotate(cell_type='E')
        spiketrains.append(spiketrain)

    # check with sanity plot here
    if params['make_plots']:
        
        # count the number of collisions for each pair of units
        fig, axs = plt.subplots(2, 2)
        count = [st.size for st in spiketrains]
        ax = axs[0, 0]
        simpleaxis(ax)
        pairs = []
        collision_count_by_pair = []
        collision_count_by_units = np.zeros(n_cell)
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                pairs.append(f'{i}-{j}')
                collision_count_by_pair.append(matching_event.size)
                collision_count_by_units[i] += matching_event.size
                collision_count_by_units[j] += matching_event.size
        ax.plot(np.arange(len(collision_count_by_pair)), collision_count_by_pair)
        ax.set_xticks(np.arange(len(collision_count_by_pair)))
        ax.set_xticklabels(pairs)
        ax.set_ylim(0, max(collision_count_by_pair) * 1.1)
        ax.set_ylabel('# Collisions')
        ax.set_xlabel('Pairs')

        # count number of spike per units
        count_total = np.array([st.size for st in spiketrains])
        count_not_collision = count_total - collision_count_by_units

        ax = axs[1, 0]
        simpleaxis(ax)
        ax.bar(np.arange(n_cell)+1, count_not_collision, color='g')
        ax.bar(np.arange(n_cell)+1, collision_count_by_units, bottom=count_not_collision, color='r')
        ax.set_ylabel('# spikes')
        ax.set_xlabel('Cell id')
        ax.legend(('Not colliding', 'Colliding'), loc='best')

        # cross-correlogram
        ax = axs[0, 1]
        simpleaxis(ax)
        counts = []
        for i in range(n_cell):
            for j in range(i+1, n_cell):
                times1 = spiketrains[i].rescale('s').magnitude
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                
                #~ ax = axs[i, j]
                all_lag = matching_event['delta_frame']  / params['fs']
                count, bins  = np.histogram(all_lag, bins=np.arange(-params['lag_time'], params['lag_time'], params['lag_time']/20))
                #~ ax.bar(bins[:-1], count, bins[1] - bins[0])
                ax.plot(1000*bins[:-1], count, c='0.5')
                counts += [count]
        counts = np.array(counts)
        counts = np.mean(counts, 0)
        ax.plot(1000*bins[:-1], counts, c='r')
        ax.set_xlabel('Lags [ms]')
        ax.set_ylabel('# Collisions')

        ax = axs[1, 1]
        simpleaxis(ax)
        ratios = []
        for i in range(n_cell):
            nb_spikes = len(spiketrains[i])
            nb_collisions = 0
            times1 = spiketrains[i].rescale('s').magnitude
            for j in list(range(0, i)) + list(range(i+1, n_cell)):
                times2 = spiketrains[j].rescale('s').magnitude
                matching_event = make_matching_events((times1*params['fs']).astype('int64'), (times2*params['fs']).astype('int64'), lag_sample)
                nb_collisions += matching_event.size

            if nb_collisions > 0:
                ratios += [nb_spikes / nb_collisions]
            else:
                ratios += [0]

        ax.bar([0], [np.mean(ratios)], yerr=[np.std(ratios)])
        ax.set_ylabel('# spikes / # collisions')
        plt.tight_layout()

        if params['save_plots']:
            plt.savefig(plot_filename)
        else:
            plt.show()
        plt.close()

    if params['generate_recording']:
        spgen = mr.SpikeTrainGenerator(spiketrains=spiketrains)
        rec_params = mr.get_default_recordings_params()
        rec_params['recordings']['fs'] = params['fs']
        rec_params['recordings']['sync_rate'] = None
        rec_params['recordings']['sync_jitter'] = 5
        rec_params['recordings']['noise_level'] = params['noise_level']
        rec_params['recordings']['filter'] = False
        rec_params['spiketrains']['duration'] = params['duration']
        rec_params['spiketrains']['n_exc'] = params['n_cell']
        rec_params['spiketrains']['n_inh'] = 0
        rec_params['recordings']['chunk_duration'] = 10.
        rec_params['templates']['n_overlap_pairs'] = None
        rec_params['templates']['min_dist'] = 0
        rec_params['seeds']['templates'] = params['templates_seed']
        rec_params['seeds']['noise'] = params['noise_seed']
        recgen = mr.gen_recordings(params=rec_params, spgen=spgen, templates=template_filename, verbose=True)
        mr.save_recording_generator(recgen, filename=recording_filename)
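
Two steps of the function above can be tried in isolation on toy data: the removal of refractory violations and the counting of collisions between two spike trains. The following is a minimal sketch. remove_refractory_violations mirrors the np.diff / np.delete logic of the function, while count_collisions is a hypothetical stand-in written for this example (it is not the make_matching_events helper, and its exact tolerance convention is an assumption).

```python
import numpy as np

def remove_refractory_violations(times, refractory_period):
    """Drop every spike that follows its predecessor by less than refractory_period."""
    times = np.sort(np.asarray(times, dtype=float))
    too_close, = np.nonzero(np.diff(times) < refractory_period)
    return np.delete(times, too_close + 1)

def count_collisions(frames1, frames2, lag_sample):
    """Count pairs (one spike per train) separated by at most lag_sample frames.

    Hypothetical helper for illustration only.
    """
    frames1 = np.asarray(frames1, dtype='int64')
    frames2 = np.sort(np.asarray(frames2, dtype='int64'))
    # For each spike of train 1, locate the window of train 2 within +/- lag_sample
    first = np.searchsorted(frames2, frames1 - lag_sample, side='left')
    last = np.searchsorted(frames2, frames1 + lag_sample, side='right')
    return int(np.sum(last - first))

lag_time = 0.002                    # s, as in generation_params
refractory_period = 2 * lag_time    # s

# Toy spike times in seconds (made up)
raw_times = [0.000, 0.001, 0.010, 0.012, 0.013, 0.030]
clean_times = remove_refractory_violations(raw_times, refractory_period)

# Toy spike frames (made up) and a 60-frame lag, i.e. 2 ms at 30 kHz
frames_a = [100, 500, 1000]
frames_b = [130, 700, 1050, 1061]
n_collisions = count_collisions(frames_a, frames_b, lag_sample=60)
```

Note that the removal step is conservative: a spike is dropped whenever it is too close to the previous spike of the original train, even if that previous spike is itself removed.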

Once this function is defined, we can write a second function that generates several recordings, each with its own suffix and seeds:

In [6]:
def generate_recordings(params=generation_params):
    # use the dictionary passed as argument rather than the global one
    for i in range(params['nb_recordings']):
        params['study_number'] = i
        params['templates_seed'] = i
        params['noise_seed'] = i
        generate_single_recording(params)
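
The function above writes the study number and seeds into the dictionary it receives, so the caller's parameters are modified as a side effect. If that is not desired, one possible variant (a sketch; the helper name per_recording_params is hypothetical) builds an independent copy for each recording:

```python
def per_recording_params(params):
    """Yield one independent parameter dict per recording (hypothetical helper)."""
    for i in range(params['nb_recordings']):
        single = dict(params)  # shallow copy: the caller's dict is left untouched
        single['study_number'] = i
        single['templates_seed'] = i
        single['noise_seed'] = i
        yield single

# Reduced, made-up parameter set for illustration
base = {'nb_recordings': 3, 'study_number': 0, 'templates_seed': 42, 'noise_seed': 42}
variants = list(per_recording_params(base))
```

Each element of variants could then be passed to generate_single_recording, while base keeps its original seeds.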

We now have all the tools required to create our recordings. By default, they will all be saved in the folder ../data/recordings/

In [7]:
## Provide the different rate and correlations levels you want to generate
rate_levels = [5, 10, 15]
corr_levels = [0, 0.1, 0.2]
generation_params['nb_recordings'] = 5 # Number of recordings per condition
In [ ]:
result = {}

for rate_level in rate_levels:
    for corr_level in corr_levels:

        generation_params['rate_level'] = rate_level
        generation_params['corr_level'] = corr_level
        generate_recordings(generation_params)
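
With three rate levels, three correlation levels and five recordings per condition, this loop produces 45 files. The sketch below lists the file names expected from the naming pattern used in generate_single_recording, which can help check that a generation run is complete. The helper name expected_recording_files is hypothetical, and its defaults simply repeat the values of generation_params.

```python
from itertools import product
from pathlib import Path

def expected_recording_files(folder, rate_levels, corr_levels, nb_recordings,
                             n_cell=20, noise_level=5, probe='Neuronexus-32'):
    """List the recording files expected for every (rate, correlation) condition."""
    files = []
    for rate_level, corr_level in product(rate_levels, corr_levels):
        for i in range(nb_recordings):
            name = (f'rec{i}_{n_cell}cells_{noise_level}noise_'
                    f'{corr_level}corr_{rate_level}rate_{probe}.h5')
            files.append(Path(folder) / name)
    return files

files = expected_recording_files('../data/recordings', [5, 10, 15], [0, 0.1, 0.2],
                                 nb_recordings=5)
# Files that have not been generated (yet) on this machine
missing = [f for f in files if not f.exists()]
```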

Generation of the study objects

Now that the recordings have been generated, we need to create Study objects for spikeinterface and run the sorters on all these recordings. Be aware that, by default, this can create a large amount of data if you have many rate/correlation levels, recordings, or sorters. First, we need to tell spikeinterface where to find the sorters:

In [11]:
ironclust_path = '/media/cure/Secondary/pierre/softwares/ironclust'
kilosort1_path = '/media/cure/Secondary/pierre/softwares/Kilosort-1.0'
kilosort2_path = '/media/cure/Secondary/pierre/softwares/Kilosort-2.0'
kilosort3_path = '/media/cure/Secondary/pierre/softwares/Kilosort-3.0'
hdsort_path = '/media/cure/Secondary/pierre/softwares/HDsort'
os.environ["KILOSORT_PATH"] = kilosort1_path
os.environ["KILOSORT2_PATH"] = kilosort2_path
os.environ["KILOSORT3_PATH"] = kilosort3_path
os.environ['IRONCLUST_PATH'] = ironclust_path
os.environ['HDSORT_PATH'] = hdsort_path
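
The paths above are specific to the machine used by the authors and need to be adapted. A small sanity check, sketched here with the standard library only (the helper name unresolved_sorter_paths is hypothetical), reports the variables that are unset or that do not point to an existing folder:

```python
import os

def unresolved_sorter_paths(variable_names, environ=os.environ):
    """Return the variables that are unset or do not point to an existing folder."""
    problems = []
    for name in variable_names:
        path = environ.get(name)
        if path is None or not os.path.isdir(path):
            problems.append(name)
    return problems

names = ['KILOSORT_PATH', 'KILOSORT2_PATH', 'KILOSORT3_PATH',
         'IRONCLUST_PATH', 'HDSORT_PATH']
print(unresolved_sorter_paths(names))
```

This only checks that the folders exist; it does not verify that the sorters themselves are usable.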

Then we need a function that, given a list of recordings, creates a study and runs all the sorters:

In [13]:
def generate_study(params, keep_data=True):
    paths = {}
    paths['basedir'] = params['global_path']
    paths['data'] = None

    if paths['data'] is None:
        paths['data'] = os.path.join(paths['basedir'], 'data')

    paths['templates'] =  os.path.join(paths['data'], 'templates')
    paths['recordings'] = os.path.join(paths['data'], 'recordings')
    paths['study'] = os.path.join(paths['data'], 'study')
    
    for i in paths.values():
        if not os.path.exists(i):
            os.makedirs(i)

    probe = params['probe']
    n_cell = params['n_cell']
    noise_level = params['noise_level']
    study_number = params['study_number']
    corr_level = params['corr_level']
    rate_level = params['rate_level']

    paths['mearec_filename'] = []

    study_folder = os.path.join(paths['study'], f'{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}')
    study_folder = Path(study_folder)

    if params['reset_study'] and os.path.exists(study_folder):
        shutil.rmtree(study_folder)

    print('Available sorters:')
    si.print_sorter_versions()

    gt_dict = {}

    if not os.path.exists(study_folder):

        for i in range(params['nb_recordings']):
            paths['mearec_filename'] += [os.path.join(paths['recordings'], f'rec{i}_{n_cell}cells_{noise_level}noise_{corr_level}corr_{rate_level}rate_{probe}.h5')]

        print('Available recordings:')
        print(paths['mearec_filename'])

        
        for count, file in enumerate(paths['mearec_filename']):
            rec  = si.MEArecRecordingExtractor(file)
            sorting_gt = si.MEArecSortingExtractor(file)
            gt_dict['rec%d' %count] = (rec, sorting_gt)

        study = si.GroundTruthStudy.create(study_folder, gt_dict, n_jobs=-1, chunk_memory='1G', progress_bar=True)
        study.run_sorters(params['sorter_list'], verbose=False, docker_images=params['docker_images'])
        print("Study created!")
    else:
        study = si.GroundTruthStudy(study_folder)
        if params['relaunch'] == 'all':
            if_exist = 'overwrite'
        elif params['relaunch'] == 'some':
            if_exist = 'keep'

        if params['relaunch'] in ['all', 'some']:
            study.run_sorters(params['sorter_list'], verbose=False, mode_if_folder_exists=if_exist, docker_images=params['docker_images'])
            print("Study loaded!")

    study.copy_sortings()

    if not keep_data:

        for sorter in params['sorter_list']:

            for rec in ['rec%d' %i for i in range(params['nb_recordings'])]:
                sorter_path = os.path.join(study_folder, 'sorter_folders', rec, sorter)
                if os.path.exists(sorter_path):
                    for f in os.listdir(sorter_path):
                        if f != 'spikeinterface_log.json':
                            full_file = os.path.join(sorter_path, f)
                            try:
                                if os.path.isdir(full_file):
                                    shutil.rmtree(full_file)
                                else:
                                    os.remove(full_file)
                            except Exception:
                                pass
        for file in paths['mearec_filename']:
            os.remove(file)

    return study
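
The cleanup performed when keep_data=False removes everything from each sorter folder except spikeinterface_log.json. The same idea can be tried safely on a throwaway folder; the following is a sketch using pathlib, where the helper name clean_sorter_folder and the toy folder layout are assumptions made for this example.

```python
import shutil
import tempfile
from pathlib import Path

def clean_sorter_folder(sorter_path, keep=('spikeinterface_log.json',)):
    """Delete everything in sorter_path except the entries named in keep."""
    sorter_path = Path(sorter_path)
    if not sorter_path.is_dir():
        return
    for entry in sorter_path.iterdir():
        if entry.name in keep:
            continue
        if entry.is_dir():
            shutil.rmtree(entry, ignore_errors=True)
        else:
            entry.unlink()

# Demonstration on a temporary folder that mimics one sorter output
with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp) / 'sorter_folders' / 'rec0' / 'toy_sorter'
    (folder / 'intermediate').mkdir(parents=True)
    (folder / 'spikeinterface_log.json').write_text('{}')
    (folder / 'big_output.bin').write_bytes(b'\0' * 16)
    clean_sorter_folder(folder)
    remaining = sorted(p.name for p in folder.iterdir())
```

Unlike the function above, this sketch does not silence errors raised while deleting individual files, so a failed deletion is reported instead of being ignored.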

This function takes a dictionary of inputs (the same as for generating the recordings, plus a few sorter-related keys) and, looping over all the recordings of a given condition (probe, rate and correlation levels), creates a study in ../data/study/ and runs all the sorters on these recordings. This can take a long time, depending on the number of recordings and sorters. Note also that, by default, the original recordings generated by MEArec are kept, and thus duplicated in the study folder. If you want to delete the original recordings (they are not needed for further analysis), you can set keep_data=False

In [14]:
study_params = generation_params.copy()
study_params['sorter_list'] = ['yass', 'kilosort', 'kilosort2', 'kilosort3', 'spykingcircus', 'tridesclous', 'ironclust', 'herdingspikes', 'hdsort']
study_params['docker_images'] = {'yass' : 'spikeinterface/yass-base:2.0.0'} #If some sorters are installed via docker
study_params['relaunch'] = 'all' #If you want to relaunch the sorters. 
study_params['reset_study'] = False #If you want to reset the study (delete everything)
In [ ]:
all_studies = {}
for rate_level in rate_levels:
    for corr_level in corr_levels:

        study_params['rate_level'] = rate_level
        study_params['corr_level'] = corr_level
        all_studies[corr_level, rate_level] = generate_study(study_params)

And that is it! You should now have several studies, each with several recordings that have been analyzed by several sorters, organized by rate and correlation level.